1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
15 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/RegionGraphTraits.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23
24 #define DEBUG_TYPE "pdl-bytecode"
25
26 using namespace mlir;
27 using namespace mlir::detail;
28
29 //===----------------------------------------------------------------------===//
30 // PDLByteCodePattern
31 //===----------------------------------------------------------------------===//
32
create(pdl_interp::RecordMatchOp matchOp,ByteCodeAddr rewriterAddr)33 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
34 ByteCodeAddr rewriterAddr) {
35 SmallVector<StringRef, 8> generatedOps;
36 if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr())
37 generatedOps =
38 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
39
40 PatternBenefit benefit = matchOp.benefit();
41 MLIRContext *ctx = matchOp.getContext();
42
43 // Check to see if this is pattern matches a specific operation type.
44 if (Optional<StringRef> rootKind = matchOp.rootKind())
45 return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit,
46 ctx);
47 return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx,
48 MatchAnyOpTypeTag());
49 }
50
51 //===----------------------------------------------------------------------===//
52 // PDLByteCodeMutableState
53 //===----------------------------------------------------------------------===//
54
55 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
56 /// to the position of the pattern within the range returned by
57 /// `PDLByteCode::getPatterns`.
updatePatternBenefit(unsigned patternIndex,PatternBenefit benefit)58 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
59 PatternBenefit benefit) {
60 currentPatternBenefits[patternIndex] = benefit;
61 }
62
63 //===----------------------------------------------------------------------===//
64 // Bytecode OpCodes
65 //===----------------------------------------------------------------------===//
66
67 namespace {
68 enum OpCode : ByteCodeField {
69 /// Apply an externally registered constraint.
70 ApplyConstraint,
71 /// Apply an externally registered rewrite.
72 ApplyRewrite,
73 /// Check if two generic values are equal.
74 AreEqual,
75 /// Unconditional branch.
76 Branch,
77 /// Compare the operand count of an operation with a constant.
78 CheckOperandCount,
79 /// Compare the name of an operation with a constant.
80 CheckOperationName,
81 /// Compare the result count of an operation with a constant.
82 CheckResultCount,
83 /// Invoke a native creation method.
84 CreateNative,
85 /// Create an operation.
86 CreateOperation,
87 /// Erase an operation.
88 EraseOp,
89 /// Terminate a matcher or rewrite sequence.
90 Finalize,
91 /// Get a specific attribute of an operation.
92 GetAttribute,
93 /// Get the type of an attribute.
94 GetAttributeType,
95 /// Get the defining operation of a value.
96 GetDefiningOp,
97 /// Get a specific operand of an operation.
98 GetOperand0,
99 GetOperand1,
100 GetOperand2,
101 GetOperand3,
102 GetOperandN,
103 /// Get a specific result of an operation.
104 GetResult0,
105 GetResult1,
106 GetResult2,
107 GetResult3,
108 GetResultN,
109 /// Get the type of a value.
110 GetValueType,
111 /// Check if a generic value is not null.
112 IsNotNull,
113 /// Record a successful pattern match.
114 RecordMatch,
115 /// Replace an operation.
116 ReplaceOp,
117 /// Compare an attribute with a set of constants.
118 SwitchAttribute,
119 /// Compare the operand count of an operation with a set of constants.
120 SwitchOperandCount,
121 /// Compare the name of an operation with a set of constants.
122 SwitchOperationName,
123 /// Compare the result count of an operation with a set of constants.
124 SwitchResultCount,
125 /// Compare a type with a set of constants.
126 SwitchType,
127 };
128
129 enum class PDLValueKind { Attribute, Operation, Type, Value };
130 } // end anonymous namespace
131
132 //===----------------------------------------------------------------------===//
133 // ByteCode Generation
134 //===----------------------------------------------------------------------===//
135
136 //===----------------------------------------------------------------------===//
137 // Generator
138
139 namespace {
140 struct ByteCodeWriter;
141
142 /// This class represents the main generator for the pattern bytecode.
143 class Generator {
144 public:
Generator(MLIRContext * ctx,std::vector<const void * > & uniquedData,SmallVectorImpl<ByteCodeField> & matcherByteCode,SmallVectorImpl<ByteCodeField> & rewriterByteCode,SmallVectorImpl<PDLByteCodePattern> & patterns,ByteCodeField & maxValueMemoryIndex,llvm::StringMap<PDLConstraintFunction> & constraintFns,llvm::StringMap<PDLCreateFunction> & createFns,llvm::StringMap<PDLRewriteFunction> & rewriteFns)145 Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
146 SmallVectorImpl<ByteCodeField> &matcherByteCode,
147 SmallVectorImpl<ByteCodeField> &rewriterByteCode,
148 SmallVectorImpl<PDLByteCodePattern> &patterns,
149 ByteCodeField &maxValueMemoryIndex,
150 llvm::StringMap<PDLConstraintFunction> &constraintFns,
151 llvm::StringMap<PDLCreateFunction> &createFns,
152 llvm::StringMap<PDLRewriteFunction> &rewriteFns)
153 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
154 rewriterByteCode(rewriterByteCode), patterns(patterns),
155 maxValueMemoryIndex(maxValueMemoryIndex) {
156 for (auto it : llvm::enumerate(constraintFns))
157 constraintToMemIndex.try_emplace(it.value().first(), it.index());
158 for (auto it : llvm::enumerate(createFns))
159 nativeCreateToMemIndex.try_emplace(it.value().first(), it.index());
160 for (auto it : llvm::enumerate(rewriteFns))
161 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
162 }
163
164 /// Generate the bytecode for the given PDL interpreter module.
165 void generate(ModuleOp module);
166
167 /// Return the memory index to use for the given value.
getMemIndex(Value value)168 ByteCodeField &getMemIndex(Value value) {
169 assert(valueToMemIndex.count(value) &&
170 "expected memory index to be assigned");
171 return valueToMemIndex[value];
172 }
173
174 /// Return an index to use when referring to the given data that is uniqued in
175 /// the MLIR context.
176 template <typename T>
177 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
getMemIndex(T val)178 getMemIndex(T val) {
179 const void *opaqueVal = val.getAsOpaquePointer();
180
181 // Get or insert a reference to this value.
182 auto it = uniquedDataToMemIndex.try_emplace(
183 opaqueVal, maxValueMemoryIndex + uniquedData.size());
184 if (it.second)
185 uniquedData.push_back(opaqueVal);
186 return it.first->second;
187 }
188
189 private:
190 /// Allocate memory indices for the results of operations within the matcher
191 /// and rewriters.
192 void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
193
194 /// Generate the bytecode for the given operation.
195 void generate(Operation *op, ByteCodeWriter &writer);
196 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
197 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
198 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
199 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
200 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
201 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
202 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
203 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
204 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
205 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
206 void generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer);
207 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
208 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
209 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
210 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
211 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
212 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
213 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
214 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
215 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
216 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
217 void generate(pdl_interp::InferredTypeOp op, ByteCodeWriter &writer);
218 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
219 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
220 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
221 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
222 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
223 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
224 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
225 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
226
227 /// Mapping from value to its corresponding memory index.
228 DenseMap<Value, ByteCodeField> valueToMemIndex;
229
230 /// Mapping from the name of an externally registered rewrite to its index in
231 /// the bytecode registry.
232 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
233
234 /// Mapping from the name of an externally registered constraint to its index
235 /// in the bytecode registry.
236 llvm::StringMap<ByteCodeField> constraintToMemIndex;
237
238 /// Mapping from the name of an externally registered creation method to its
239 /// index in the bytecode registry.
240 llvm::StringMap<ByteCodeField> nativeCreateToMemIndex;
241
242 /// Mapping from rewriter function name to the bytecode address of the
243 /// rewriter function in byte.
244 llvm::StringMap<ByteCodeAddr> rewriterToAddr;
245
246 /// Mapping from a uniqued storage object to its memory index within
247 /// `uniquedData`.
248 DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
249
250 /// The current MLIR context.
251 MLIRContext *ctx;
252
253 /// Data of the ByteCode class to be populated.
254 std::vector<const void *> &uniquedData;
255 SmallVectorImpl<ByteCodeField> &matcherByteCode;
256 SmallVectorImpl<ByteCodeField> &rewriterByteCode;
257 SmallVectorImpl<PDLByteCodePattern> &patterns;
258 ByteCodeField &maxValueMemoryIndex;
259 };
260
261 /// This class provides utilities for writing a bytecode stream.
262 struct ByteCodeWriter {
ByteCodeWriter__anon8ccc2f200211::ByteCodeWriter263 ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
264 : bytecode(bytecode), generator(generator) {}
265
266 /// Append a field to the bytecode.
append__anon8ccc2f200211::ByteCodeWriter267 void append(ByteCodeField field) { bytecode.push_back(field); }
append__anon8ccc2f200211::ByteCodeWriter268 void append(OpCode opCode) { bytecode.push_back(opCode); }
269
270 /// Append an address to the bytecode.
append__anon8ccc2f200211::ByteCodeWriter271 void append(ByteCodeAddr field) {
272 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
273 "unexpected ByteCode address size");
274
275 ByteCodeField fieldParts[2];
276 std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
277 bytecode.append({fieldParts[0], fieldParts[1]});
278 }
279
280 /// Append a successor range to the bytecode, the exact address will need to
281 /// be resolved later.
append__anon8ccc2f200211::ByteCodeWriter282 void append(SuccessorRange successors) {
283 // Add back references to the any successors so that the address can be
284 // resolved later.
285 for (Block *successor : successors) {
286 unresolvedSuccessorRefs[successor].push_back(bytecode.size());
287 append(ByteCodeAddr(0));
288 }
289 }
290
291 /// Append a range of values that will be read as generic PDLValues.
appendPDLValueList__anon8ccc2f200211::ByteCodeWriter292 void appendPDLValueList(OperandRange values) {
293 bytecode.push_back(values.size());
294 for (Value value : values) {
295 // Append the type of the value in addition to the value itself.
296 PDLValueKind kind =
297 TypeSwitch<Type, PDLValueKind>(value.getType())
298 .Case<pdl::AttributeType>(
299 [](Type) { return PDLValueKind::Attribute; })
300 .Case<pdl::OperationType>(
301 [](Type) { return PDLValueKind::Operation; })
302 .Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; })
303 .Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; });
304 bytecode.push_back(static_cast<ByteCodeField>(kind));
305 append(value);
306 }
307 }
308
309 /// Check if the given class `T` has an iterator type.
310 template <typename T, typename... Args>
311 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
312
313 /// Append a value that will be stored in a memory slot and not inline within
314 /// the bytecode.
315 template <typename T>
316 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
317 std::is_pointer<T>::value>
append__anon8ccc2f200211::ByteCodeWriter318 append(T value) {
319 bytecode.push_back(generator.getMemIndex(value));
320 }
321
322 /// Append a range of values.
323 template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
324 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
append__anon8ccc2f200211::ByteCodeWriter325 append(T range) {
326 bytecode.push_back(llvm::size(range));
327 for (auto it : range)
328 append(it);
329 }
330
331 /// Append a variadic number of fields to the bytecode.
332 template <typename FieldTy, typename Field2Ty, typename... FieldTys>
append__anon8ccc2f200211::ByteCodeWriter333 void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
334 append(field);
335 append(field2, fields...);
336 }
337
338 /// Successor references in the bytecode that have yet to be resolved.
339 DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
340
341 /// The underlying bytecode buffer.
342 SmallVectorImpl<ByteCodeField> &bytecode;
343
344 /// The main generator producing PDL.
345 Generator &generator;
346 };
347 } // end anonymous namespace
348
generate(ModuleOp module)349 void Generator::generate(ModuleOp module) {
350 FuncOp matcherFunc = module.lookupSymbol<FuncOp>(
351 pdl_interp::PDLInterpDialect::getMatcherFunctionName());
352 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
353 pdl_interp::PDLInterpDialect::getRewriterModuleName());
354 assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
355
356 // Allocate memory indices for the results of operations within the matcher
357 // and rewriters.
358 allocateMemoryIndices(matcherFunc, rewriterModule);
359
360 // Generate code for the rewriter functions.
361 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
362 for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
363 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
364 for (Operation &op : rewriterFunc.getOps())
365 generate(&op, rewriterByteCodeWriter);
366 }
367 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
368 "unexpected branches in rewriter function");
369
370 // Generate code for the matcher function.
371 DenseMap<Block *, ByteCodeAddr> blockToAddr;
372 llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody());
373 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
374 for (Block *block : rpot) {
375 // Keep track of where this block begins within the matcher function.
376 blockToAddr.try_emplace(block, matcherByteCode.size());
377 for (Operation &op : *block)
378 generate(&op, matcherByteCodeWriter);
379 }
380
381 // Resolve successor references in the matcher.
382 for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
383 ByteCodeAddr addr = blockToAddr[it.first];
384 for (unsigned offsetToFix : it.second)
385 std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
386 }
387 }
388
allocateMemoryIndices(FuncOp matcherFunc,ModuleOp rewriterModule)389 void Generator::allocateMemoryIndices(FuncOp matcherFunc,
390 ModuleOp rewriterModule) {
391 // Rewriters use simplistic allocation scheme that simply assigns an index to
392 // each result.
393 for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
394 ByteCodeField index = 0;
395 for (BlockArgument arg : rewriterFunc.getArguments())
396 valueToMemIndex.try_emplace(arg, index++);
397 rewriterFunc.getBody().walk([&](Operation *op) {
398 for (Value result : op->getResults())
399 valueToMemIndex.try_emplace(result, index++);
400 });
401 if (index > maxValueMemoryIndex)
402 maxValueMemoryIndex = index;
403 }
404
405 // The matcher function uses a more sophisticated numbering that tries to
406 // minimize the number of memory indices assigned. This is done by determining
407 // a live range of the values within the matcher, then the allocation is just
408 // finding the minimal number of overlapping live ranges. This is essentially
409 // a simplified form of register allocation where we don't necessarily have a
410 // limited number of registers, but we still want to minimize the number used.
411 DenseMap<Operation *, ByteCodeField> opToIndex;
412 matcherFunc.getBody().walk([&](Operation *op) {
413 opToIndex.insert(std::make_pair(op, opToIndex.size()));
414 });
415
416 // Liveness info for each of the defs within the matcher.
417 using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>;
418 LivenessSet::Allocator allocator;
419 DenseMap<Value, LivenessSet> valueDefRanges;
420
421 // Assign the root operation being matched to slot 0.
422 BlockArgument rootOpArg = matcherFunc.getArgument(0);
423 valueToMemIndex[rootOpArg] = 0;
424
425 // Walk each of the blocks, computing the def interval that the value is used.
426 Liveness matcherLiveness(matcherFunc);
427 for (Block &block : matcherFunc.getBody()) {
428 const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block);
429 assert(info && "expected liveness info for block");
430 auto processValue = [&](Value value, Operation *firstUseOrDef) {
431 // We don't need to process the root op argument, this value is always
432 // assigned to the first memory slot.
433 if (value == rootOpArg)
434 return;
435
436 // Set indices for the range of this block that the value is used.
437 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
438 defRangeIt->second.insert(
439 opToIndex[firstUseOrDef],
440 opToIndex[info->getEndOperation(value, firstUseOrDef)],
441 /*dummyValue*/ 0);
442 };
443
444 // Process the live-ins of this block.
445 for (Value liveIn : info->in())
446 processValue(liveIn, &block.front());
447
448 // Process any new defs within this block.
449 for (Operation &op : block)
450 for (Value result : op.getResults())
451 processValue(result, &op);
452 }
453
454 // Greedily allocate memory slots using the computed def live ranges.
455 std::vector<LivenessSet> allocatedIndices;
456 for (auto &defIt : valueDefRanges) {
457 ByteCodeField &memIndex = valueToMemIndex[defIt.first];
458 LivenessSet &defSet = defIt.second;
459
460 // Try to allocate to an existing index.
461 for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
462 LivenessSet &existingIndex = existingIndexIt.value();
463 llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps(
464 defIt.second, existingIndex);
465 if (overlaps.valid())
466 continue;
467 // Union the range of the def within the existing index.
468 for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
469 existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0);
470 memIndex = existingIndexIt.index() + 1;
471 }
472
473 // If no existing index could be used, add a new one.
474 if (memIndex == 0) {
475 allocatedIndices.emplace_back(allocator);
476 for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
477 allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0);
478 memIndex = allocatedIndices.size();
479 }
480 }
481
482 // Update the max number of indices.
483 ByteCodeField numMatcherIndices = allocatedIndices.size() + 1;
484 if (numMatcherIndices > maxValueMemoryIndex)
485 maxValueMemoryIndex = numMatcherIndices;
486 }
487
generate(Operation * op,ByteCodeWriter & writer)488 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
489 TypeSwitch<Operation *>(op)
490 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
491 pdl_interp::AreEqualOp, pdl_interp::BranchOp,
492 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
493 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
494 pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp,
495 pdl_interp::CreateNativeOp, pdl_interp::CreateOperationOp,
496 pdl_interp::CreateTypeOp, pdl_interp::EraseOp,
497 pdl_interp::FinalizeOp, pdl_interp::GetAttributeOp,
498 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
499 pdl_interp::GetOperandOp, pdl_interp::GetResultOp,
500 pdl_interp::GetValueTypeOp, pdl_interp::InferredTypeOp,
501 pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
502 pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
503 pdl_interp::SwitchTypeOp, pdl_interp::SwitchOperandCountOp,
504 pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
505 [&](auto interpOp) { this->generate(interpOp, writer); })
506 .Default([](Operation *) {
507 llvm_unreachable("unknown `pdl_interp` operation");
508 });
509 }
510
generate(pdl_interp::ApplyConstraintOp op,ByteCodeWriter & writer)511 void Generator::generate(pdl_interp::ApplyConstraintOp op,
512 ByteCodeWriter &writer) {
513 assert(constraintToMemIndex.count(op.name()) &&
514 "expected index for constraint function");
515 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
516 op.constParamsAttr());
517 writer.appendPDLValueList(op.args());
518 writer.append(op.getSuccessors());
519 }
generate(pdl_interp::ApplyRewriteOp op,ByteCodeWriter & writer)520 void Generator::generate(pdl_interp::ApplyRewriteOp op,
521 ByteCodeWriter &writer) {
522 assert(externalRewriterToMemIndex.count(op.name()) &&
523 "expected index for rewrite function");
524 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
525 op.constParamsAttr(), op.root());
526 writer.appendPDLValueList(op.args());
527 }
generate(pdl_interp::AreEqualOp op,ByteCodeWriter & writer)528 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
529 writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors());
530 }
generate(pdl_interp::BranchOp op,ByteCodeWriter & writer)531 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
532 writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
533 }
generate(pdl_interp::CheckAttributeOp op,ByteCodeWriter & writer)534 void Generator::generate(pdl_interp::CheckAttributeOp op,
535 ByteCodeWriter &writer) {
536 writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
537 op.getSuccessors());
538 }
generate(pdl_interp::CheckOperandCountOp op,ByteCodeWriter & writer)539 void Generator::generate(pdl_interp::CheckOperandCountOp op,
540 ByteCodeWriter &writer) {
541 writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
542 op.getSuccessors());
543 }
generate(pdl_interp::CheckOperationNameOp op,ByteCodeWriter & writer)544 void Generator::generate(pdl_interp::CheckOperationNameOp op,
545 ByteCodeWriter &writer) {
546 writer.append(OpCode::CheckOperationName, op.operation(),
547 OperationName(op.name(), ctx), op.getSuccessors());
548 }
generate(pdl_interp::CheckResultCountOp op,ByteCodeWriter & writer)549 void Generator::generate(pdl_interp::CheckResultCountOp op,
550 ByteCodeWriter &writer) {
551 writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
552 op.getSuccessors());
553 }
generate(pdl_interp::CheckTypeOp op,ByteCodeWriter & writer)554 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
555 writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
556 }
generate(pdl_interp::CreateAttributeOp op,ByteCodeWriter & writer)557 void Generator::generate(pdl_interp::CreateAttributeOp op,
558 ByteCodeWriter &writer) {
559 // Simply repoint the memory index of the result to the constant.
560 getMemIndex(op.attribute()) = getMemIndex(op.value());
561 }
generate(pdl_interp::CreateNativeOp op,ByteCodeWriter & writer)562 void Generator::generate(pdl_interp::CreateNativeOp op,
563 ByteCodeWriter &writer) {
564 assert(nativeCreateToMemIndex.count(op.name()) &&
565 "expected index for creation function");
566 writer.append(OpCode::CreateNative, nativeCreateToMemIndex[op.name()],
567 op.result(), op.constParamsAttr());
568 writer.appendPDLValueList(op.args());
569 }
generate(pdl_interp::CreateOperationOp op,ByteCodeWriter & writer)570 void Generator::generate(pdl_interp::CreateOperationOp op,
571 ByteCodeWriter &writer) {
572 writer.append(OpCode::CreateOperation, op.operation(),
573 OperationName(op.name(), ctx), op.operands());
574
575 // Add the attributes.
576 OperandRange attributes = op.attributes();
577 writer.append(static_cast<ByteCodeField>(attributes.size()));
578 for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
579 writer.append(
580 Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx),
581 std::get<1>(it));
582 }
583 writer.append(op.types());
584 }
generate(pdl_interp::CreateTypeOp op,ByteCodeWriter & writer)585 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
586 // Simply repoint the memory index of the result to the constant.
587 getMemIndex(op.result()) = getMemIndex(op.value());
588 }
generate(pdl_interp::EraseOp op,ByteCodeWriter & writer)589 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
590 writer.append(OpCode::EraseOp, op.operation());
591 }
generate(pdl_interp::FinalizeOp op,ByteCodeWriter & writer)592 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
593 writer.append(OpCode::Finalize);
594 }
generate(pdl_interp::GetAttributeOp op,ByteCodeWriter & writer)595 void Generator::generate(pdl_interp::GetAttributeOp op,
596 ByteCodeWriter &writer) {
597 writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
598 Identifier::get(op.name(), ctx));
599 }
generate(pdl_interp::GetAttributeTypeOp op,ByteCodeWriter & writer)600 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
601 ByteCodeWriter &writer) {
602 writer.append(OpCode::GetAttributeType, op.result(), op.value());
603 }
generate(pdl_interp::GetDefiningOpOp op,ByteCodeWriter & writer)604 void Generator::generate(pdl_interp::GetDefiningOpOp op,
605 ByteCodeWriter &writer) {
606 writer.append(OpCode::GetDefiningOp, op.operation(), op.value());
607 }
generate(pdl_interp::GetOperandOp op,ByteCodeWriter & writer)608 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
609 uint32_t index = op.index();
610 if (index < 4)
611 writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
612 else
613 writer.append(OpCode::GetOperandN, index);
614 writer.append(op.operation(), op.value());
615 }
generate(pdl_interp::GetResultOp op,ByteCodeWriter & writer)616 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
617 uint32_t index = op.index();
618 if (index < 4)
619 writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
620 else
621 writer.append(OpCode::GetResultN, index);
622 writer.append(op.operation(), op.value());
623 }
generate(pdl_interp::GetValueTypeOp op,ByteCodeWriter & writer)624 void Generator::generate(pdl_interp::GetValueTypeOp op,
625 ByteCodeWriter &writer) {
626 writer.append(OpCode::GetValueType, op.result(), op.value());
627 }
generate(pdl_interp::InferredTypeOp op,ByteCodeWriter & writer)628 void Generator::generate(pdl_interp::InferredTypeOp op,
629 ByteCodeWriter &writer) {
630 // InferType maps to a null type as a marker for inferring a result type.
631 getMemIndex(op.type()) = getMemIndex(Type());
632 }
generate(pdl_interp::IsNotNullOp op,ByteCodeWriter & writer)633 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
634 writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
635 }
generate(pdl_interp::RecordMatchOp op,ByteCodeWriter & writer)636 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
637 ByteCodeField patternIndex = patterns.size();
638 patterns.emplace_back(PDLByteCodePattern::create(
639 op, rewriterToAddr[op.rewriter().getLeafReference()]));
640 writer.append(OpCode::RecordMatch, patternIndex,
641 SuccessorRange(op.getOperation()), op.matchedOps(),
642 op.inputs());
643 }
generate(pdl_interp::ReplaceOp op,ByteCodeWriter & writer)644 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
645 writer.append(OpCode::ReplaceOp, op.operation(), op.replValues());
646 }
generate(pdl_interp::SwitchAttributeOp op,ByteCodeWriter & writer)647 void Generator::generate(pdl_interp::SwitchAttributeOp op,
648 ByteCodeWriter &writer) {
649 writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
650 op.getSuccessors());
651 }
generate(pdl_interp::SwitchOperandCountOp op,ByteCodeWriter & writer)652 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
653 ByteCodeWriter &writer) {
654 writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
655 op.getSuccessors());
656 }
generate(pdl_interp::SwitchOperationNameOp op,ByteCodeWriter & writer)657 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
658 ByteCodeWriter &writer) {
659 auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
660 return OperationName(attr.cast<StringAttr>().getValue(), ctx);
661 });
662 writer.append(OpCode::SwitchOperationName, op.operation(), cases,
663 op.getSuccessors());
664 }
generate(pdl_interp::SwitchResultCountOp op,ByteCodeWriter & writer)665 void Generator::generate(pdl_interp::SwitchResultCountOp op,
666 ByteCodeWriter &writer) {
667 writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
668 op.getSuccessors());
669 }
generate(pdl_interp::SwitchTypeOp op,ByteCodeWriter & writer)670 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
671 writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
672 op.getSuccessors());
673 }
674
675 //===----------------------------------------------------------------------===//
676 // PDLByteCode
677 //===----------------------------------------------------------------------===//
678
PDLByteCode(ModuleOp module,llvm::StringMap<PDLConstraintFunction> constraintFns,llvm::StringMap<PDLCreateFunction> createFns,llvm::StringMap<PDLRewriteFunction> rewriteFns)679 PDLByteCode::PDLByteCode(ModuleOp module,
680 llvm::StringMap<PDLConstraintFunction> constraintFns,
681 llvm::StringMap<PDLCreateFunction> createFns,
682 llvm::StringMap<PDLRewriteFunction> rewriteFns) {
683 Generator generator(module.getContext(), uniquedData, matcherByteCode,
684 rewriterByteCode, patterns, maxValueMemoryIndex,
685 constraintFns, createFns, rewriteFns);
686 generator.generate(module);
687
688 // Initialize the external functions.
689 for (auto &it : constraintFns)
690 constraintFunctions.push_back(std::move(it.second));
691 for (auto &it : createFns)
692 createFunctions.push_back(std::move(it.second));
693 for (auto &it : rewriteFns)
694 rewriteFunctions.push_back(std::move(it.second));
695 }
696
697 /// Initialize the given state such that it can be used to execute the current
698 /// bytecode.
initializeMutableState(PDLByteCodeMutableState & state) const699 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
700 state.memory.resize(maxValueMemoryIndex, nullptr);
701 state.currentPatternBenefits.reserve(patterns.size());
702 for (const PDLByteCodePattern &pattern : patterns)
703 state.currentPatternBenefits.push_back(pattern.getBenefit());
704 }
705
//===----------------------------------------------------------------------===//
// ByteCode Execution
//===----------------------------------------------------------------------===//
708
709 namespace {
710 /// This class provides support for executing a bytecode stream.
711 class ByteCodeExecutor {
712 public:
ByteCodeExecutor(const ByteCodeField * curCodeIt,MutableArrayRef<const void * > memory,ArrayRef<const void * > uniquedMemory,ArrayRef<ByteCodeField> code,ArrayRef<PatternBenefit> currentPatternBenefits,ArrayRef<PDLByteCodePattern> patterns,ArrayRef<PDLConstraintFunction> constraintFunctions,ArrayRef<PDLCreateFunction> createFunctions,ArrayRef<PDLRewriteFunction> rewriteFunctions)713 ByteCodeExecutor(const ByteCodeField *curCodeIt,
714 MutableArrayRef<const void *> memory,
715 ArrayRef<const void *> uniquedMemory,
716 ArrayRef<ByteCodeField> code,
717 ArrayRef<PatternBenefit> currentPatternBenefits,
718 ArrayRef<PDLByteCodePattern> patterns,
719 ArrayRef<PDLConstraintFunction> constraintFunctions,
720 ArrayRef<PDLCreateFunction> createFunctions,
721 ArrayRef<PDLRewriteFunction> rewriteFunctions)
722 : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory),
723 code(code), currentPatternBenefits(currentPatternBenefits),
724 patterns(patterns), constraintFunctions(constraintFunctions),
725 createFunctions(createFunctions), rewriteFunctions(rewriteFunctions) {}
726
727 /// Start executing the code at the current bytecode index. `matches` is an
728 /// optional field provided when this function is executed in a matching
729 /// context.
730 void execute(PatternRewriter &rewriter,
731 SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
732 Optional<Location> mainRewriteLoc = {});
733
734 private:
735 /// Read a value from the bytecode buffer, optionally skipping a certain
736 /// number of prefix values. These methods always update the buffer to point
737 /// to the next field after the read data.
738 template <typename T = ByteCodeField>
read(size_t skipN=0)739 T read(size_t skipN = 0) {
740 curCodeIt += skipN;
741 return readImpl<T>();
742 }
read(size_t skipN=0)743 ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
744
745 /// Read a list of values from the bytecode buffer.
746 template <typename ValueT, typename T>
readList(SmallVectorImpl<T> & list)747 void readList(SmallVectorImpl<T> &list) {
748 list.clear();
749 for (unsigned i = 0, e = read(); i != e; ++i)
750 list.push_back(read<ValueT>());
751 }
752
753 /// Jump to a specific successor based on a predicate value.
selectJump(bool isTrue)754 void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
755 /// Jump to a specific successor based on a destination index.
selectJump(size_t destIndex)756 void selectJump(size_t destIndex) {
757 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
758 }
759
760 /// Handle a switch operation with the provided value and cases.
761 template <typename T, typename RangeT>
handleSwitch(const T & value,RangeT && cases)762 void handleSwitch(const T &value, RangeT &&cases) {
763 LLVM_DEBUG({
764 llvm::dbgs() << " * Value: " << value << "\n"
765 << " * Cases: ";
766 llvm::interleaveComma(cases, llvm::dbgs());
767 llvm::dbgs() << "\n\n";
768 });
769
770 // Check to see if the attribute value is within the case list. Jump to
771 // the correct successor index based on the result.
772 for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
773 if (*it == value)
774 return selectJump(size_t((it - cases.begin()) + 1));
775 selectJump(size_t(0));
776 }
777
778 /// Internal implementation of reading various data types from the bytecode
779 /// stream.
780 template <typename T>
readFromMemory()781 const void *readFromMemory() {
782 size_t index = *curCodeIt++;
783
784 // If this type is an SSA value, it can only be stored in non-const memory.
785 if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size())
786 return memory[index];
787
788 // Otherwise, if this index is not inbounds it is uniqued.
789 return uniquedMemory[index - memory.size()];
790 }
791 template <typename T>
readImpl()792 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
793 return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
794 }
795 template <typename T>
796 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
797 T>
readImpl()798 readImpl() {
799 return T(T::getFromOpaquePointer(readFromMemory<T>()));
800 }
801 template <typename T>
readImpl()802 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
803 switch (static_cast<PDLValueKind>(read())) {
804 case PDLValueKind::Attribute:
805 return read<Attribute>();
806 case PDLValueKind::Operation:
807 return read<Operation *>();
808 case PDLValueKind::Type:
809 return read<Type>();
810 case PDLValueKind::Value:
811 return read<Value>();
812 }
813 }
814 template <typename T>
readImpl()815 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
816 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
817 "unexpected ByteCode address size");
818 ByteCodeAddr result;
819 std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
820 curCodeIt += 2;
821 return result;
822 }
823 template <typename T>
readImpl()824 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
825 return *curCodeIt++;
826 }
827
828 /// The underlying bytecode buffer.
829 const ByteCodeField *curCodeIt;
830
831 /// The current execution memory.
832 MutableArrayRef<const void *> memory;
833
834 /// References to ByteCode data necessary for execution.
835 ArrayRef<const void *> uniquedMemory;
836 ArrayRef<ByteCodeField> code;
837 ArrayRef<PatternBenefit> currentPatternBenefits;
838 ArrayRef<PDLByteCodePattern> patterns;
839 ArrayRef<PDLConstraintFunction> constraintFunctions;
840 ArrayRef<PDLCreateFunction> createFunctions;
841 ArrayRef<PDLRewriteFunction> rewriteFunctions;
842 };
843 } // end anonymous namespace
844
execute(PatternRewriter & rewriter,SmallVectorImpl<PDLByteCode::MatchResult> * matches,Optional<Location> mainRewriteLoc)845 void ByteCodeExecutor::execute(
846 PatternRewriter &rewriter,
847 SmallVectorImpl<PDLByteCode::MatchResult> *matches,
848 Optional<Location> mainRewriteLoc) {
849 while (true) {
850 OpCode opCode = static_cast<OpCode>(read());
851 switch (opCode) {
852 case ApplyConstraint: {
853 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
854 const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
855 ArrayAttr constParams = read<ArrayAttr>();
856 SmallVector<PDLValue, 16> args;
857 readList<PDLValue>(args);
858 LLVM_DEBUG({
859 llvm::dbgs() << " * Arguments: ";
860 llvm::interleaveComma(args, llvm::dbgs());
861 llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n";
862 });
863
864 // Invoke the constraint and jump to the proper destination.
865 selectJump(succeeded(constraintFn(args, constParams, rewriter)));
866 break;
867 }
868 case ApplyRewrite: {
869 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
870 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
871 ArrayAttr constParams = read<ArrayAttr>();
872 Operation *root = read<Operation *>();
873 SmallVector<PDLValue, 16> args;
874 readList<PDLValue>(args);
875
876 LLVM_DEBUG({
877 llvm::dbgs() << " * Root: " << *root << "\n"
878 << " * Arguments: ";
879 llvm::interleaveComma(args, llvm::dbgs());
880 llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n";
881 });
882 rewriteFn(root, args, constParams, rewriter);
883 break;
884 }
885 case AreEqual: {
886 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
887 const void *lhs = read<const void *>();
888 const void *rhs = read<const void *>();
889
890 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
891 selectJump(lhs == rhs);
892 break;
893 }
894 case Branch: {
895 LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n\n");
896 curCodeIt = &code[read<ByteCodeAddr>()];
897 break;
898 }
899 case CheckOperandCount: {
900 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
901 Operation *op = read<Operation *>();
902 uint32_t expectedCount = read<uint32_t>();
903
904 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n"
905 << " * Expected: " << expectedCount << "\n\n");
906 selectJump(op->getNumOperands() == expectedCount);
907 break;
908 }
909 case CheckOperationName: {
910 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
911 Operation *op = read<Operation *>();
912 OperationName expectedName = read<OperationName>();
913
914 LLVM_DEBUG(llvm::dbgs()
915 << " * Found: \"" << op->getName() << "\"\n"
916 << " * Expected: \"" << expectedName << "\"\n\n");
917 selectJump(op->getName() == expectedName);
918 break;
919 }
920 case CheckResultCount: {
921 LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
922 Operation *op = read<Operation *>();
923 uint32_t expectedCount = read<uint32_t>();
924
925 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n"
926 << " * Expected: " << expectedCount << "\n\n");
927 selectJump(op->getNumResults() == expectedCount);
928 break;
929 }
930 case CreateNative: {
931 LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n");
932 const PDLCreateFunction &createFn = createFunctions[read()];
933 ByteCodeField resultIndex = read();
934 ArrayAttr constParams = read<ArrayAttr>();
935 SmallVector<PDLValue, 16> args;
936 readList<PDLValue>(args);
937
938 LLVM_DEBUG({
939 llvm::dbgs() << " * Arguments: ";
940 llvm::interleaveComma(args, llvm::dbgs());
941 llvm::dbgs() << "\n * Parameters: " << constParams << "\n";
942 });
943
944 PDLValue result = createFn(args, constParams, rewriter);
945 memory[resultIndex] = result.getAsOpaquePointer();
946
947 LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n\n");
948 break;
949 }
950 case CreateOperation: {
951 LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
952 assert(mainRewriteLoc && "expected rewrite loc to be provided when "
953 "executing the rewriter bytecode");
954
955 unsigned memIndex = read();
956 OperationState state(*mainRewriteLoc, read<OperationName>());
957 readList<Value>(state.operands);
958 for (unsigned i = 0, e = read(); i != e; ++i) {
959 Identifier name = read<Identifier>();
960 if (Attribute attr = read<Attribute>())
961 state.addAttribute(name, attr);
962 }
963
964 bool hasInferredTypes = false;
965 for (unsigned i = 0, e = read(); i != e; ++i) {
966 Type resultType = read<Type>();
967 hasInferredTypes |= !resultType;
968 state.types.push_back(resultType);
969 }
970
971 // Handle the case where the operation has inferred types.
972 if (hasInferredTypes) {
973 InferTypeOpInterface::Concept *concept =
974 state.name.getAbstractOperation()
975 ->getInterface<InferTypeOpInterface>();
976
977 // TODO: Handle failure.
978 SmallVector<Type, 2> inferredTypes;
979 if (failed(concept->inferReturnTypes(
980 state.getContext(), state.location, state.operands,
981 state.attributes.getDictionary(state.getContext()),
982 state.regions, inferredTypes)))
983 return;
984
985 for (unsigned i = 0, e = state.types.size(); i != e; ++i)
986 if (!state.types[i])
987 state.types[i] = inferredTypes[i];
988 }
989 Operation *resultOp = rewriter.createOperation(state);
990 memory[memIndex] = resultOp;
991
992 LLVM_DEBUG({
993 llvm::dbgs() << " * Attributes: "
994 << state.attributes.getDictionary(state.getContext())
995 << "\n * Operands: ";
996 llvm::interleaveComma(state.operands, llvm::dbgs());
997 llvm::dbgs() << "\n * Result Types: ";
998 llvm::interleaveComma(state.types, llvm::dbgs());
999 llvm::dbgs() << "\n * Result: " << *resultOp << "\n\n";
1000 });
1001 break;
1002 }
1003 case EraseOp: {
1004 LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1005 Operation *op = read<Operation *>();
1006
1007 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n\n");
1008 rewriter.eraseOp(op);
1009 break;
1010 }
1011 case Finalize: {
1012 LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n");
1013 return;
1014 }
1015 case GetAttribute: {
1016 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1017 unsigned memIndex = read();
1018 Operation *op = read<Operation *>();
1019 Identifier attrName = read<Identifier>();
1020 Attribute attr = op->getAttr(attrName);
1021
1022 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1023 << " * Attribute: " << attrName << "\n"
1024 << " * Result: " << attr << "\n\n");
1025 memory[memIndex] = attr.getAsOpaquePointer();
1026 break;
1027 }
1028 case GetAttributeType: {
1029 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1030 unsigned memIndex = read();
1031 Attribute attr = read<Attribute>();
1032
1033 LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
1034 << " * Result: " << attr.getType() << "\n\n");
1035 memory[memIndex] = attr.getType().getAsOpaquePointer();
1036 break;
1037 }
1038 case GetDefiningOp: {
1039 LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1040 unsigned memIndex = read();
1041 Value value = read<Value>();
1042 Operation *op = value ? value.getDefiningOp() : nullptr;
1043
1044 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
1045 << " * Result: " << *op << "\n\n");
1046 memory[memIndex] = op;
1047 break;
1048 }
1049 case GetOperand0:
1050 case GetOperand1:
1051 case GetOperand2:
1052 case GetOperand3:
1053 case GetOperandN: {
1054 LLVM_DEBUG({
1055 llvm::dbgs() << "Executing GetOperand"
1056 << (opCode == GetOperandN ? Twine("N")
1057 : Twine(opCode - GetOperand0))
1058 << ":\n";
1059 });
1060 unsigned index =
1061 opCode == GetOperandN ? read<uint32_t>() : (opCode - GetOperand0);
1062 Operation *op = read<Operation *>();
1063 unsigned memIndex = read();
1064 Value operand =
1065 index < op->getNumOperands() ? op->getOperand(index) : Value();
1066
1067 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1068 << " * Index: " << index << "\n"
1069 << " * Result: " << operand << "\n\n");
1070 memory[memIndex] = operand.getAsOpaquePointer();
1071 break;
1072 }
1073 case GetResult0:
1074 case GetResult1:
1075 case GetResult2:
1076 case GetResult3:
1077 case GetResultN: {
1078 LLVM_DEBUG({
1079 llvm::dbgs() << "Executing GetResult"
1080 << (opCode == GetResultN ? Twine("N")
1081 : Twine(opCode - GetResult0))
1082 << ":\n";
1083 });
1084 unsigned index =
1085 opCode == GetResultN ? read<uint32_t>() : (opCode - GetResult0);
1086 Operation *op = read<Operation *>();
1087 unsigned memIndex = read();
1088 OpResult result =
1089 index < op->getNumResults() ? op->getResult(index) : OpResult();
1090
1091 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1092 << " * Index: " << index << "\n"
1093 << " * Result: " << result << "\n\n");
1094 memory[memIndex] = result.getAsOpaquePointer();
1095 break;
1096 }
1097 case GetValueType: {
1098 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1099 unsigned memIndex = read();
1100 Value value = read<Value>();
1101
1102 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
1103 << " * Result: " << value.getType() << "\n\n");
1104 memory[memIndex] = value.getType().getAsOpaquePointer();
1105 break;
1106 }
1107 case IsNotNull: {
1108 LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1109 const void *value = read<const void *>();
1110
1111 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n\n");
1112 selectJump(value != nullptr);
1113 break;
1114 }
1115 case RecordMatch: {
1116 LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1117 assert(matches &&
1118 "expected matches to be provided when executing the matcher");
1119 unsigned patternIndex = read();
1120 PatternBenefit benefit = currentPatternBenefits[patternIndex];
1121 const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1122
1123 // If the benefit of the pattern is impossible, skip the processing of the
1124 // rest of the pattern.
1125 if (benefit.isImpossibleToMatch()) {
1126 LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n\n");
1127 curCodeIt = dest;
1128 break;
1129 }
1130
1131 // Create a fused location containing the locations of each of the
1132 // operations used in the match. This will be used as the location for
1133 // created operations during the rewrite that don't already have an
1134 // explicit location set.
1135 unsigned numMatchLocs = read();
1136 SmallVector<Location, 4> matchLocs;
1137 matchLocs.reserve(numMatchLocs);
1138 for (unsigned i = 0; i != numMatchLocs; ++i)
1139 matchLocs.push_back(read<Operation *>()->getLoc());
1140 Location matchLoc = rewriter.getFusedLoc(matchLocs);
1141
1142 LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
1143 << " * Location: " << matchLoc << "\n\n");
1144 matches->emplace_back(matchLoc, patterns[patternIndex], benefit);
1145 readList<const void *>(matches->back().values);
1146 curCodeIt = dest;
1147 break;
1148 }
1149 case ReplaceOp: {
1150 LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1151 Operation *op = read<Operation *>();
1152 SmallVector<Value, 16> args;
1153 readList<Value>(args);
1154
1155 LLVM_DEBUG({
1156 llvm::dbgs() << " * Operation: " << *op << "\n"
1157 << " * Values: ";
1158 llvm::interleaveComma(args, llvm::dbgs());
1159 llvm::dbgs() << "\n\n";
1160 });
1161 rewriter.replaceOp(op, args);
1162 break;
1163 }
1164 case SwitchAttribute: {
1165 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1166 Attribute value = read<Attribute>();
1167 ArrayAttr cases = read<ArrayAttr>();
1168 handleSwitch(value, cases);
1169 break;
1170 }
1171 case SwitchOperandCount: {
1172 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1173 Operation *op = read<Operation *>();
1174 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1175
1176 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
1177 handleSwitch(op->getNumOperands(), cases);
1178 break;
1179 }
1180 case SwitchOperationName: {
1181 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
1182 OperationName value = read<Operation *>()->getName();
1183 size_t caseCount = read();
1184
1185 // The operation names are stored in-line, so to print them out for
1186 // debugging purposes we need to read the array before executing the
1187 // switch so that we can display all of the possible values.
1188 LLVM_DEBUG({
1189 const ByteCodeField *prevCodeIt = curCodeIt;
1190 llvm::dbgs() << " * Value: " << value << "\n"
1191 << " * Cases: ";
1192 llvm::interleaveComma(
1193 llvm::map_range(llvm::seq<size_t>(0, caseCount),
1194 [&](size_t i) { return read<OperationName>(); }),
1195 llvm::dbgs());
1196 llvm::dbgs() << "\n\n";
1197 curCodeIt = prevCodeIt;
1198 });
1199
1200 // Try to find the switch value within any of the cases.
1201 size_t jumpDest = 0;
1202 for (size_t i = 0; i != caseCount; ++i) {
1203 if (read<OperationName>() == value) {
1204 curCodeIt += (caseCount - i - 1);
1205 jumpDest = i + 1;
1206 break;
1207 }
1208 }
1209 selectJump(jumpDest);
1210 break;
1211 }
1212 case SwitchResultCount: {
1213 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1214 Operation *op = read<Operation *>();
1215 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1216
1217 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
1218 handleSwitch(op->getNumResults(), cases);
1219 break;
1220 }
1221 case SwitchType: {
1222 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
1223 Type value = read<Type>();
1224 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
1225 handleSwitch(value, cases);
1226 break;
1227 }
1228 }
1229 }
1230 }
1231
1232 /// Run the pattern matcher on the given root operation, collecting the matched
1233 /// patterns in `matches`.
match(Operation * op,PatternRewriter & rewriter,SmallVectorImpl<MatchResult> & matches,PDLByteCodeMutableState & state) const1234 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
1235 SmallVectorImpl<MatchResult> &matches,
1236 PDLByteCodeMutableState &state) const {
1237 // The first memory slot is always the root operation.
1238 state.memory[0] = op;
1239
1240 // The matcher function always starts at code address 0.
1241 ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData,
1242 matcherByteCode, state.currentPatternBenefits,
1243 patterns, constraintFunctions, createFunctions,
1244 rewriteFunctions);
1245 executor.execute(rewriter, &matches);
1246
1247 // Order the found matches by benefit.
1248 std::stable_sort(matches.begin(), matches.end(),
1249 [](const MatchResult &lhs, const MatchResult &rhs) {
1250 return lhs.benefit > rhs.benefit;
1251 });
1252 }
1253
1254 /// Run the rewriter of the given pattern on the root operation `op`.
rewrite(PatternRewriter & rewriter,const MatchResult & match,PDLByteCodeMutableState & state) const1255 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
1256 PDLByteCodeMutableState &state) const {
1257 // The arguments of the rewrite function are stored at the start of the
1258 // memory buffer.
1259 llvm::copy(match.values, state.memory.begin());
1260
1261 ByteCodeExecutor executor(
1262 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
1263 uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns,
1264 constraintFunctions, createFunctions, rewriteFunctions);
1265 executor.execute(rewriter, /*matches=*/nullptr, match.location);
1266 }
1267