//===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements MLIR to byte-code generation and the interpreter. // //===----------------------------------------------------------------------===// #include "ByteCode.h" #include "mlir/Analysis/Liveness.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/RegionGraphTraits.h" #include "llvm/ADT/IntervalMap.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "pdl-bytecode" using namespace mlir; using namespace mlir::detail; //===----------------------------------------------------------------------===// // PDLByteCodePattern //===----------------------------------------------------------------------===// PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, ByteCodeAddr rewriterAddr) { SmallVector generatedOps; if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr()) generatedOps = llvm::to_vector<8>(generatedOpsAttr.getAsValueRange()); PatternBenefit benefit = matchOp.benefit(); MLIRContext *ctx = matchOp.getContext(); // Check to see if this is pattern matches a specific operation type. if (Optional rootKind = matchOp.rootKind()) return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit, ctx); return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx, MatchAnyOpTypeTag()); } //===----------------------------------------------------------------------===// // PDLByteCodeMutableState //===----------------------------------------------------------------------===// /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds /// to the position of the pattern within the range returned by /// `PDLByteCode::getPatterns`. void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit) { currentPatternBenefits[patternIndex] = benefit; } //===----------------------------------------------------------------------===// // Bytecode OpCodes //===----------------------------------------------------------------------===// namespace { enum OpCode : ByteCodeField { /// Apply an externally registered constraint. ApplyConstraint, /// Apply an externally registered rewrite. ApplyRewrite, /// Check if two generic values are equal. AreEqual, /// Unconditional branch. Branch, /// Compare the operand count of an operation with a constant. CheckOperandCount, /// Compare the name of an operation with a constant. CheckOperationName, /// Compare the result count of an operation with a constant. CheckResultCount, /// Invoke a native creation method. CreateNative, /// Create an operation. CreateOperation, /// Erase an operation. EraseOp, /// Terminate a matcher or rewrite sequence. Finalize, /// Get a specific attribute of an operation. GetAttribute, /// Get the type of an attribute. GetAttributeType, /// Get the defining operation of a value. GetDefiningOp, /// Get a specific operand of an operation. GetOperand0, GetOperand1, GetOperand2, GetOperand3, GetOperandN, /// Get a specific result of an operation. GetResult0, GetResult1, GetResult2, GetResult3, GetResultN, /// Get the type of a value. GetValueType, /// Check if a generic value is not null. IsNotNull, /// Record a successful pattern match. RecordMatch, /// Replace an operation. ReplaceOp, /// Compare an attribute with a set of constants. SwitchAttribute, /// Compare the operand count of an operation with a set of constants. SwitchOperandCount, /// Compare the name of an operation with a set of constants. SwitchOperationName, /// Compare the result count of an operation with a set of constants. SwitchResultCount, /// Compare a type with a set of constants. SwitchType, }; enum class PDLValueKind { Attribute, Operation, Type, Value }; } // end anonymous namespace //===----------------------------------------------------------------------===// // ByteCode Generation //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// // Generator namespace { struct ByteCodeWriter; /// This class represents the main generator for the pattern bytecode. class Generator { public: Generator(MLIRContext *ctx, std::vector &uniquedData, SmallVectorImpl &matcherByteCode, SmallVectorImpl &rewriterByteCode, SmallVectorImpl &patterns, ByteCodeField &maxValueMemoryIndex, llvm::StringMap &constraintFns, llvm::StringMap &createFns, llvm::StringMap &rewriteFns) : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), rewriterByteCode(rewriterByteCode), patterns(patterns), maxValueMemoryIndex(maxValueMemoryIndex) { for (auto it : llvm::enumerate(constraintFns)) constraintToMemIndex.try_emplace(it.value().first(), it.index()); for (auto it : llvm::enumerate(createFns)) nativeCreateToMemIndex.try_emplace(it.value().first(), it.index()); for (auto it : llvm::enumerate(rewriteFns)) externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); } /// Generate the bytecode for the given PDL interpreter module. void generate(ModuleOp module); /// Return the memory index to use for the given value. ByteCodeField &getMemIndex(Value value) { assert(valueToMemIndex.count(value) && "expected memory index to be assigned"); return valueToMemIndex[value]; } /// Return an index to use when referring to the given data that is uniqued in /// the MLIR context. template std::enable_if_t::value, ByteCodeField &> getMemIndex(T val) { const void *opaqueVal = val.getAsOpaquePointer(); // Get or insert a reference to this value. auto it = uniquedDataToMemIndex.try_emplace( opaqueVal, maxValueMemoryIndex + uniquedData.size()); if (it.second) uniquedData.push_back(opaqueVal); return it.first->second; } private: /// Allocate memory indices for the results of operations within the matcher /// and rewriters. void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule); /// Generate the bytecode for the given operation. void generate(Operation *op, ByteCodeWriter &writer); void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::InferredTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); /// Mapping from value to its corresponding memory index. DenseMap valueToMemIndex; /// Mapping from the name of an externally registered rewrite to its index in /// the bytecode registry. llvm::StringMap externalRewriterToMemIndex; /// Mapping from the name of an externally registered constraint to its index /// in the bytecode registry. llvm::StringMap constraintToMemIndex; /// Mapping from the name of an externally registered creation method to its /// index in the bytecode registry. llvm::StringMap nativeCreateToMemIndex; /// Mapping from rewriter function name to the bytecode address of the /// rewriter function in byte. llvm::StringMap rewriterToAddr; /// Mapping from a uniqued storage object to its memory index within /// `uniquedData`. DenseMap uniquedDataToMemIndex; /// The current MLIR context. MLIRContext *ctx; /// Data of the ByteCode class to be populated. std::vector &uniquedData; SmallVectorImpl &matcherByteCode; SmallVectorImpl &rewriterByteCode; SmallVectorImpl &patterns; ByteCodeField &maxValueMemoryIndex; }; /// This class provides utilities for writing a bytecode stream. struct ByteCodeWriter { ByteCodeWriter(SmallVectorImpl &bytecode, Generator &generator) : bytecode(bytecode), generator(generator) {} /// Append a field to the bytecode. void append(ByteCodeField field) { bytecode.push_back(field); } void append(OpCode opCode) { bytecode.push_back(opCode); } /// Append an address to the bytecode. void append(ByteCodeAddr field) { static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, "unexpected ByteCode address size"); ByteCodeField fieldParts[2]; std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr)); bytecode.append({fieldParts[0], fieldParts[1]}); } /// Append a successor range to the bytecode, the exact address will need to /// be resolved later. void append(SuccessorRange successors) { // Add back references to the any successors so that the address can be // resolved later. for (Block *successor : successors) { unresolvedSuccessorRefs[successor].push_back(bytecode.size()); append(ByteCodeAddr(0)); } } /// Append a range of values that will be read as generic PDLValues. void appendPDLValueList(OperandRange values) { bytecode.push_back(values.size()); for (Value value : values) { // Append the type of the value in addition to the value itself. PDLValueKind kind = TypeSwitch(value.getType()) .Case( [](Type) { return PDLValueKind::Attribute; }) .Case( [](Type) { return PDLValueKind::Operation; }) .Case([](Type) { return PDLValueKind::Type; }) .Case([](Type) { return PDLValueKind::Value; }); bytecode.push_back(static_cast(kind)); append(value); } } /// Check if the given class `T` has an iterator type. template using has_pointer_traits = decltype(std::declval().getAsOpaquePointer()); /// Append a value that will be stored in a memory slot and not inline within /// the bytecode. template std::enable_if_t::value || std::is_pointer::value> append(T value) { bytecode.push_back(generator.getMemIndex(value)); } /// Append a range of values. template > std::enable_if_t::value> append(T range) { bytecode.push_back(llvm::size(range)); for (auto it : range) append(it); } /// Append a variadic number of fields to the bytecode. template void append(FieldTy field, Field2Ty field2, FieldTys... fields) { append(field); append(field2, fields...); } /// Successor references in the bytecode that have yet to be resolved. DenseMap> unresolvedSuccessorRefs; /// The underlying bytecode buffer. SmallVectorImpl &bytecode; /// The main generator producing PDL. Generator &generator; }; } // end anonymous namespace void Generator::generate(ModuleOp module) { FuncOp matcherFunc = module.lookupSymbol( pdl_interp::PDLInterpDialect::getMatcherFunctionName()); ModuleOp rewriterModule = module.lookupSymbol( pdl_interp::PDLInterpDialect::getRewriterModuleName()); assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); // Allocate memory indices for the results of operations within the matcher // and rewriters. allocateMemoryIndices(matcherFunc, rewriterModule); // Generate code for the rewriter functions. ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); for (FuncOp rewriterFunc : rewriterModule.getOps()) { rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); for (Operation &op : rewriterFunc.getOps()) generate(&op, rewriterByteCodeWriter); } assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() && "unexpected branches in rewriter function"); // Generate code for the matcher function. DenseMap blockToAddr; llvm::ReversePostOrderTraversal rpot(&matcherFunc.getBody()); ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); for (Block *block : rpot) { // Keep track of where this block begins within the matcher function. blockToAddr.try_emplace(block, matcherByteCode.size()); for (Operation &op : *block) generate(&op, matcherByteCodeWriter); } // Resolve successor references in the matcher. for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { ByteCodeAddr addr = blockToAddr[it.first]; for (unsigned offsetToFix : it.second) std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr)); } } void Generator::allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule) { // Rewriters use simplistic allocation scheme that simply assigns an index to // each result. for (FuncOp rewriterFunc : rewriterModule.getOps()) { ByteCodeField index = 0; for (BlockArgument arg : rewriterFunc.getArguments()) valueToMemIndex.try_emplace(arg, index++); rewriterFunc.getBody().walk([&](Operation *op) { for (Value result : op->getResults()) valueToMemIndex.try_emplace(result, index++); }); if (index > maxValueMemoryIndex) maxValueMemoryIndex = index; } // The matcher function uses a more sophisticated numbering that tries to // minimize the number of memory indices assigned. This is done by determining // a live range of the values within the matcher, then the allocation is just // finding the minimal number of overlapping live ranges. This is essentially // a simplified form of register allocation where we don't necessarily have a // limited number of registers, but we still want to minimize the number used. DenseMap opToIndex; matcherFunc.getBody().walk([&](Operation *op) { opToIndex.insert(std::make_pair(op, opToIndex.size())); }); // Liveness info for each of the defs within the matcher. using LivenessSet = llvm::IntervalMap; LivenessSet::Allocator allocator; DenseMap valueDefRanges; // Assign the root operation being matched to slot 0. BlockArgument rootOpArg = matcherFunc.getArgument(0); valueToMemIndex[rootOpArg] = 0; // Walk each of the blocks, computing the def interval that the value is used. Liveness matcherLiveness(matcherFunc); for (Block &block : matcherFunc.getBody()) { const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block); assert(info && "expected liveness info for block"); auto processValue = [&](Value value, Operation *firstUseOrDef) { // We don't need to process the root op argument, this value is always // assigned to the first memory slot. if (value == rootOpArg) return; // Set indices for the range of this block that the value is used. auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; defRangeIt->second.insert( opToIndex[firstUseOrDef], opToIndex[info->getEndOperation(value, firstUseOrDef)], /*dummyValue*/ 0); }; // Process the live-ins of this block. for (Value liveIn : info->in()) processValue(liveIn, &block.front()); // Process any new defs within this block. for (Operation &op : block) for (Value result : op.getResults()) processValue(result, &op); } // Greedily allocate memory slots using the computed def live ranges. std::vector allocatedIndices; for (auto &defIt : valueDefRanges) { ByteCodeField &memIndex = valueToMemIndex[defIt.first]; LivenessSet &defSet = defIt.second; // Try to allocate to an existing index. for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) { LivenessSet &existingIndex = existingIndexIt.value(); llvm::IntervalMapOverlaps overlaps( defIt.second, existingIndex); if (overlaps.valid()) continue; // Union the range of the def within the existing index. for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it) existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0); memIndex = existingIndexIt.index() + 1; } // If no existing index could be used, add a new one. if (memIndex == 0) { allocatedIndices.emplace_back(allocator); for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it) allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0); memIndex = allocatedIndices.size(); } } // Update the max number of indices. ByteCodeField numMatcherIndices = allocatedIndices.size() + 1; if (numMatcherIndices > maxValueMemoryIndex) maxValueMemoryIndex = numMatcherIndices; } void Generator::generate(Operation *op, ByteCodeWriter &writer) { TypeSwitch(op) .Case( [&](auto interpOp) { this->generate(interpOp, writer); }) .Default([](Operation *) { llvm_unreachable("unknown `pdl_interp` operation"); }); } void Generator::generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer) { assert(constraintToMemIndex.count(op.name()) && "expected index for constraint function"); writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()], op.constParamsAttr()); writer.appendPDLValueList(op.args()); writer.append(op.getSuccessors()); } void Generator::generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer) { assert(externalRewriterToMemIndex.count(op.name()) && "expected index for rewrite function"); writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()], op.constParamsAttr(), op.root()); writer.appendPDLValueList(op.args()); } void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors()); } void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) { writer.append(OpCode::Branch, SuccessorRange(op.getOperation())); } void Generator::generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer) { writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(), op.getSuccessors()); } void Generator::generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer) { writer.append(OpCode::CheckOperandCount, op.operation(), op.count(), op.getSuccessors()); } void Generator::generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer) { writer.append(OpCode::CheckOperationName, op.operation(), OperationName(op.name(), ctx), op.getSuccessors()); } void Generator::generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer) { writer.append(OpCode::CheckResultCount, op.operation(), op.count(), op.getSuccessors()); } void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) { writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors()); } void Generator::generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer) { // Simply repoint the memory index of the result to the constant. getMemIndex(op.attribute()) = getMemIndex(op.value()); } void Generator::generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer) { assert(nativeCreateToMemIndex.count(op.name()) && "expected index for creation function"); writer.append(OpCode::CreateNative, nativeCreateToMemIndex[op.name()], op.result(), op.constParamsAttr()); writer.appendPDLValueList(op.args()); } void Generator::generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer) { writer.append(OpCode::CreateOperation, op.operation(), OperationName(op.name(), ctx), op.operands()); // Add the attributes. OperandRange attributes = op.attributes(); writer.append(static_cast(attributes.size())); for (auto it : llvm::zip(op.attributeNames(), op.attributes())) { writer.append( Identifier::get(std::get<0>(it).cast().getValue(), ctx), std::get<1>(it)); } writer.append(op.types()); } void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { // Simply repoint the memory index of the result to the constant. getMemIndex(op.result()) = getMemIndex(op.value()); } void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { writer.append(OpCode::EraseOp, op.operation()); } void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) { writer.append(OpCode::Finalize); } void Generator::generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer) { writer.append(OpCode::GetAttribute, op.attribute(), op.operation(), Identifier::get(op.name(), ctx)); } void Generator::generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer) { writer.append(OpCode::GetAttributeType, op.result(), op.value()); } void Generator::generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer) { writer.append(OpCode::GetDefiningOp, op.operation(), op.value()); } void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) { uint32_t index = op.index(); if (index < 4) writer.append(static_cast(OpCode::GetOperand0 + index)); else writer.append(OpCode::GetOperandN, index); writer.append(op.operation(), op.value()); } void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) { uint32_t index = op.index(); if (index < 4) writer.append(static_cast(OpCode::GetResult0 + index)); else writer.append(OpCode::GetResultN, index); writer.append(op.operation(), op.value()); } void Generator::generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer) { writer.append(OpCode::GetValueType, op.result(), op.value()); } void Generator::generate(pdl_interp::InferredTypeOp op, ByteCodeWriter &writer) { // InferType maps to a null type as a marker for inferring a result type. getMemIndex(op.type()) = getMemIndex(Type()); } void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors()); } void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { ByteCodeField patternIndex = patterns.size(); patterns.emplace_back(PDLByteCodePattern::create( op, rewriterToAddr[op.rewriter().getLeafReference()])); writer.append(OpCode::RecordMatch, patternIndex, SuccessorRange(op.getOperation()), op.matchedOps(), op.inputs()); } void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) { writer.append(OpCode::ReplaceOp, op.operation(), op.replValues()); } void Generator::generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer) { writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(), op.getSuccessors()); } void Generator::generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer) { writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(), op.getSuccessors()); } void Generator::generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer) { auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) { return OperationName(attr.cast().getValue(), ctx); }); writer.append(OpCode::SwitchOperationName, op.operation(), cases, op.getSuccessors()); } void Generator::generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer) { writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(), op.getSuccessors()); } void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) { writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(), op.getSuccessors()); } //===----------------------------------------------------------------------===// // PDLByteCode //===----------------------------------------------------------------------===// PDLByteCode::PDLByteCode(ModuleOp module, llvm::StringMap constraintFns, llvm::StringMap createFns, llvm::StringMap rewriteFns) { Generator generator(module.getContext(), uniquedData, matcherByteCode, rewriterByteCode, patterns, maxValueMemoryIndex, constraintFns, createFns, rewriteFns); generator.generate(module); // Initialize the external functions. for (auto &it : constraintFns) constraintFunctions.push_back(std::move(it.second)); for (auto &it : createFns) createFunctions.push_back(std::move(it.second)); for (auto &it : rewriteFns) rewriteFunctions.push_back(std::move(it.second)); } /// Initialize the given state such that it can be used to execute the current /// bytecode. void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { state.memory.resize(maxValueMemoryIndex, nullptr); state.currentPatternBenefits.reserve(patterns.size()); for (const PDLByteCodePattern &pattern : patterns) state.currentPatternBenefits.push_back(pattern.getBenefit()); } //===----------------------------------------------------------------------===// // ByteCode Execution namespace { /// This class provides support for executing a bytecode stream. class ByteCodeExecutor { public: ByteCodeExecutor(const ByteCodeField *curCodeIt, MutableArrayRef memory, ArrayRef uniquedMemory, ArrayRef code, ArrayRef currentPatternBenefits, ArrayRef patterns, ArrayRef constraintFunctions, ArrayRef createFunctions, ArrayRef rewriteFunctions) : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory), code(code), currentPatternBenefits(currentPatternBenefits), patterns(patterns), constraintFunctions(constraintFunctions), createFunctions(createFunctions), rewriteFunctions(rewriteFunctions) {} /// Start executing the code at the current bytecode index. `matches` is an /// optional field provided when this function is executed in a matching /// context. void execute(PatternRewriter &rewriter, SmallVectorImpl *matches = nullptr, Optional mainRewriteLoc = {}); private: /// Read a value from the bytecode buffer, optionally skipping a certain /// number of prefix values. These methods always update the buffer to point /// to the next field after the read data. template T read(size_t skipN = 0) { curCodeIt += skipN; return readImpl(); } ByteCodeField read(size_t skipN = 0) { return read(skipN); } /// Read a list of values from the bytecode buffer. template void readList(SmallVectorImpl &list) { list.clear(); for (unsigned i = 0, e = read(); i != e; ++i) list.push_back(read()); } /// Jump to a specific successor based on a predicate value. void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } /// Jump to a specific successor based on a destination index. void selectJump(size_t destIndex) { curCodeIt = &code[read(destIndex * 2)]; } /// Handle a switch operation with the provided value and cases. template void handleSwitch(const T &value, RangeT &&cases) { LLVM_DEBUG({ llvm::dbgs() << " * Value: " << value << "\n" << " * Cases: "; llvm::interleaveComma(cases, llvm::dbgs()); llvm::dbgs() << "\n\n"; }); // Check to see if the attribute value is within the case list. Jump to // the correct successor index based on the result. for (auto it = cases.begin(), e = cases.end(); it != e; ++it) if (*it == value) return selectJump(size_t((it - cases.begin()) + 1)); selectJump(size_t(0)); } /// Internal implementation of reading various data types from the bytecode /// stream. template const void *readFromMemory() { size_t index = *curCodeIt++; // If this type is an SSA value, it can only be stored in non-const memory. if (llvm::is_one_of::value || index < memory.size()) return memory[index]; // Otherwise, if this index is not inbounds it is uniqued. return uniquedMemory[index - memory.size()]; } template std::enable_if_t::value, T> readImpl() { return reinterpret_cast(const_cast(readFromMemory())); } template std::enable_if_t::value && !std::is_same::value, T> readImpl() { return T(T::getFromOpaquePointer(readFromMemory())); } template std::enable_if_t::value, T> readImpl() { switch (static_cast(read())) { case PDLValueKind::Attribute: return read(); case PDLValueKind::Operation: return read(); case PDLValueKind::Type: return read(); case PDLValueKind::Value: return read(); } } template std::enable_if_t::value, T> readImpl() { static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, "unexpected ByteCode address size"); ByteCodeAddr result; std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); curCodeIt += 2; return result; } template std::enable_if_t::value, T> readImpl() { return *curCodeIt++; } /// The underlying bytecode buffer. const ByteCodeField *curCodeIt; /// The current execution memory. MutableArrayRef memory; /// References to ByteCode data necessary for execution. ArrayRef uniquedMemory; ArrayRef code; ArrayRef currentPatternBenefits; ArrayRef patterns; ArrayRef constraintFunctions; ArrayRef createFunctions; ArrayRef rewriteFunctions; }; } // end anonymous namespace void ByteCodeExecutor::execute( PatternRewriter &rewriter, SmallVectorImpl *matches, Optional mainRewriteLoc) { while (true) { OpCode opCode = static_cast(read()); switch (opCode) { case ApplyConstraint: { LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; ArrayAttr constParams = read(); SmallVector args; readList(args); LLVM_DEBUG({ llvm::dbgs() << " * Arguments: "; llvm::interleaveComma(args, llvm::dbgs()); llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n"; }); // Invoke the constraint and jump to the proper destination. selectJump(succeeded(constraintFn(args, constParams, rewriter))); break; } case ApplyRewrite: { LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; ArrayAttr constParams = read(); Operation *root = read(); SmallVector args; readList(args); LLVM_DEBUG({ llvm::dbgs() << " * Root: " << *root << "\n" << " * Arguments: "; llvm::interleaveComma(args, llvm::dbgs()); llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n"; }); rewriteFn(root, args, constParams, rewriter); break; } case AreEqual: { LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); const void *lhs = read(); const void *rhs = read(); LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); selectJump(lhs == rhs); break; } case Branch: { LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n\n"); curCodeIt = &code[read()]; break; } case CheckOperandCount: { LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); Operation *op = read(); uint32_t expectedCount = read(); LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" << " * Expected: " << expectedCount << "\n\n"); selectJump(op->getNumOperands() == expectedCount); break; } case CheckOperationName: { LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); Operation *op = read(); OperationName expectedName = read(); LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" << " * Expected: \"" << expectedName << "\"\n\n"); selectJump(op->getName() == expectedName); break; } case CheckResultCount: { LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); Operation *op = read(); uint32_t expectedCount = read(); LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" << " * Expected: " << expectedCount << "\n\n"); selectJump(op->getNumResults() == expectedCount); break; } case CreateNative: { LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n"); const PDLCreateFunction &createFn = createFunctions[read()]; ByteCodeField resultIndex = read(); ArrayAttr constParams = read(); SmallVector args; readList(args); LLVM_DEBUG({ llvm::dbgs() << " * Arguments: "; llvm::interleaveComma(args, llvm::dbgs()); llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; }); PDLValue result = createFn(args, constParams, rewriter); memory[resultIndex] = result.getAsOpaquePointer(); LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n\n"); break; } case CreateOperation: { LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); assert(mainRewriteLoc && "expected rewrite loc to be provided when " "executing the rewriter bytecode"); unsigned memIndex = read(); OperationState state(*mainRewriteLoc, read()); readList(state.operands); for (unsigned i = 0, e = read(); i != e; ++i) { Identifier name = read(); if (Attribute attr = read()) state.addAttribute(name, attr); } bool hasInferredTypes = false; for (unsigned i = 0, e = read(); i != e; ++i) { Type resultType = read(); hasInferredTypes |= !resultType; state.types.push_back(resultType); } // Handle the case where the operation has inferred types. if (hasInferredTypes) { InferTypeOpInterface::Concept *concept = state.name.getAbstractOperation() ->getInterface(); // TODO: Handle failure. SmallVector inferredTypes; if (failed(concept->inferReturnTypes( state.getContext(), state.location, state.operands, state.attributes.getDictionary(state.getContext()), state.regions, inferredTypes))) return; for (unsigned i = 0, e = state.types.size(); i != e; ++i) if (!state.types[i]) state.types[i] = inferredTypes[i]; } Operation *resultOp = rewriter.createOperation(state); memory[memIndex] = resultOp; LLVM_DEBUG({ llvm::dbgs() << " * Attributes: " << state.attributes.getDictionary(state.getContext()) << "\n * Operands: "; llvm::interleaveComma(state.operands, llvm::dbgs()); llvm::dbgs() << "\n * Result Types: "; llvm::interleaveComma(state.types, llvm::dbgs()); llvm::dbgs() << "\n * Result: " << *resultOp << "\n\n"; }); break; } case EraseOp: { LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); Operation *op = read(); LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n\n"); rewriter.eraseOp(op); break; } case Finalize: { LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n"); return; } case GetAttribute: { LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); unsigned memIndex = read(); Operation *op = read(); Identifier attrName = read(); Attribute attr = op->getAttr(attrName); LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" << " * Attribute: " << attrName << "\n" << " * Result: " << attr << "\n\n"); memory[memIndex] = attr.getAsOpaquePointer(); break; } case GetAttributeType: { LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); unsigned memIndex = read(); Attribute attr = read(); LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" << " * Result: " << attr.getType() << "\n\n"); memory[memIndex] = attr.getType().getAsOpaquePointer(); break; } case GetDefiningOp: { LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); unsigned memIndex = read(); Value value = read(); Operation *op = value ? value.getDefiningOp() : nullptr; LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" << " * Result: " << *op << "\n\n"); memory[memIndex] = op; break; } case GetOperand0: case GetOperand1: case GetOperand2: case GetOperand3: case GetOperandN: { LLVM_DEBUG({ llvm::dbgs() << "Executing GetOperand" << (opCode == GetOperandN ? Twine("N") : Twine(opCode - GetOperand0)) << ":\n"; }); unsigned index = opCode == GetOperandN ? read() : (opCode - GetOperand0); Operation *op = read(); unsigned memIndex = read(); Value operand = index < op->getNumOperands() ? op->getOperand(index) : Value(); LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" << " * Index: " << index << "\n" << " * Result: " << operand << "\n\n"); memory[memIndex] = operand.getAsOpaquePointer(); break; } case GetResult0: case GetResult1: case GetResult2: case GetResult3: case GetResultN: { LLVM_DEBUG({ llvm::dbgs() << "Executing GetResult" << (opCode == GetResultN ? Twine("N") : Twine(opCode - GetResult0)) << ":\n"; }); unsigned index = opCode == GetResultN ? read() : (opCode - GetResult0); Operation *op = read(); unsigned memIndex = read(); OpResult result = index < op->getNumResults() ? op->getResult(index) : OpResult(); LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" << " * Index: " << index << "\n" << " * Result: " << result << "\n\n"); memory[memIndex] = result.getAsOpaquePointer(); break; } case GetValueType: { LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); unsigned memIndex = read(); Value value = read(); LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" << " * Result: " << value.getType() << "\n\n"); memory[memIndex] = value.getType().getAsOpaquePointer(); break; } case IsNotNull: { LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); const void *value = read(); LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n\n"); selectJump(value != nullptr); break; } case RecordMatch: { LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); assert(matches && "expected matches to be provided when executing the matcher"); unsigned patternIndex = read(); PatternBenefit benefit = currentPatternBenefits[patternIndex]; const ByteCodeField *dest = &code[read()]; // If the benefit of the pattern is impossible, skip the processing of the // rest of the pattern. if (benefit.isImpossibleToMatch()) { LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n\n"); curCodeIt = dest; break; } // Create a fused location containing the locations of each of the // operations used in the match. This will be used as the location for // created operations during the rewrite that don't already have an // explicit location set. unsigned numMatchLocs = read(); SmallVector matchLocs; matchLocs.reserve(numMatchLocs); for (unsigned i = 0; i != numMatchLocs; ++i) matchLocs.push_back(read()->getLoc()); Location matchLoc = rewriter.getFusedLoc(matchLocs); LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" << " * Location: " << matchLoc << "\n\n"); matches->emplace_back(matchLoc, patterns[patternIndex], benefit); readList(matches->back().values); curCodeIt = dest; break; } case ReplaceOp: { LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); Operation *op = read(); SmallVector args; readList(args); LLVM_DEBUG({ llvm::dbgs() << " * Operation: " << *op << "\n" << " * Values: "; llvm::interleaveComma(args, llvm::dbgs()); llvm::dbgs() << "\n\n"; }); rewriter.replaceOp(op, args); break; } case SwitchAttribute: { LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); Attribute value = read(); ArrayAttr cases = read(); handleSwitch(value, cases); break; } case SwitchOperandCount: { LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); Operation *op = read(); auto cases = read().getValues(); LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); handleSwitch(op->getNumOperands(), cases); break; } case SwitchOperationName: { LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); OperationName value = read()->getName(); size_t caseCount = read(); // The operation names are stored in-line, so to print them out for // debugging purposes we need to read the array before executing the // switch so that we can display all of the possible values. LLVM_DEBUG({ const ByteCodeField *prevCodeIt = curCodeIt; llvm::dbgs() << " * Value: " << value << "\n" << " * Cases: "; llvm::interleaveComma( llvm::map_range(llvm::seq(0, caseCount), [&](size_t i) { return read(); }), llvm::dbgs()); llvm::dbgs() << "\n\n"; curCodeIt = prevCodeIt; }); // Try to find the switch value within any of the cases. size_t jumpDest = 0; for (size_t i = 0; i != caseCount; ++i) { if (read() == value) { curCodeIt += (caseCount - i - 1); jumpDest = i + 1; break; } } selectJump(jumpDest); break; } case SwitchResultCount: { LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); Operation *op = read(); auto cases = read().getValues(); LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); handleSwitch(op->getNumResults(), cases); break; } case SwitchType: { LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); Type value = read(); auto cases = read().getAsValueRange(); handleSwitch(value, cases); break; } } } } /// Run the pattern matcher on the given root operation, collecting the matched /// patterns in `matches`. void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl &matches, PDLByteCodeMutableState &state) const { // The first memory slot is always the root operation. state.memory[0] = op; // The matcher function always starts at code address 0. ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData, matcherByteCode, state.currentPatternBenefits, patterns, constraintFunctions, createFunctions, rewriteFunctions); executor.execute(rewriter, &matches); // Order the found matches by benefit. std::stable_sort(matches.begin(), matches.end(), [](const MatchResult &lhs, const MatchResult &rhs) { return lhs.benefit > rhs.benefit; }); } /// Run the rewriter of the given pattern on the root operation `op`. void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const { // The arguments of the rewrite function are stored at the start of the // memory buffer. llvm::copy(match.values, state.memory.begin()); ByteCodeExecutor executor( &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns, constraintFunctions, createFunctions, rewriteFunctions); executor.execute(rewriter, /*matches=*/nullptr, match.location); }