//===- NestedMatcher.cpp - NestedMatcher Impl ----------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Analysis/NestedMatcher.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; llvm::BumpPtrAllocator *&NestedMatch::allocator() { thread_local llvm::BumpPtrAllocator *allocator = nullptr; return allocator; } NestedMatch NestedMatch::build(Operation *operation, ArrayRef nestedMatches) { auto *result = allocator()->Allocate(); auto *children = allocator()->Allocate(nestedMatches.size()); std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(), children); new (result) NestedMatch(); result->matchedOperation = operation; result->matchedChildren = ArrayRef(children, nestedMatches.size()); return *result; } llvm::BumpPtrAllocator *&NestedPattern::allocator() { thread_local llvm::BumpPtrAllocator *allocator = nullptr; return allocator; } NestedPattern::NestedPattern(ArrayRef nested, FilterFunctionType filter) : nestedPatterns(), filter(filter), skip(nullptr) { if (!nested.empty()) { auto *newNested = allocator()->Allocate(nested.size()); std::uninitialized_copy(nested.begin(), nested.end(), newNested); nestedPatterns = ArrayRef(newNested, nested.size()); } } unsigned NestedPattern::getDepth() const { if (nestedPatterns.empty()) { return 1; } unsigned depth = 0; for (auto &c : nestedPatterns) { depth = std::max(depth, c.getDepth()); } return depth + 1; } /// Matches a single operation in the following way: /// 1. checks the kind of operation against the matcher, if different then /// there is no match; /// 2. calls the customizable filter function to refine the single operation /// match with extra semantic constraints; /// 3. if all is good, recursively matches the nested patterns; /// 4. if all nested match then the single operation matches too and is /// appended to the list of matches; /// 5. TODO: Optionally applies actions (lambda), in which case we will want /// to traverse in post-order DFS to avoid invalidating iterators. void NestedPattern::matchOne(Operation *op, SmallVectorImpl *matches) { if (skip == op) { return; } // Local custom filter function if (!filter(*op)) { return; } if (nestedPatterns.empty()) { SmallVector nestedMatches; matches->push_back(NestedMatch::build(op, nestedMatches)); return; } // Take a copy of each nested pattern so we can match it. for (auto nestedPattern : nestedPatterns) { SmallVector nestedMatches; // Skip elem in the walk immediately following. Without this we would // essentially need to reimplement walk here. nestedPattern.skip = op; nestedPattern.match(op, &nestedMatches); // If we could not match even one of the specified nestedPattern, early exit // as this whole branch is not a match. if (nestedMatches.empty()) { return; } matches->push_back(NestedMatch::build(op, nestedMatches)); } } static bool isAffineForOp(Operation &op) { return isa(op); } static bool isAffineIfOp(Operation &op) { return isa(op); } namespace mlir { namespace matcher { NestedPattern Op(FilterFunctionType filter) { return NestedPattern({}, filter); } NestedPattern If(NestedPattern child) { return NestedPattern(child, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, NestedPattern child) { return NestedPattern(child, [filter](Operation &op) { return isAffineIfOp(op) && filter(op); }); } NestedPattern If(ArrayRef nested) { return NestedPattern(nested, isAffineIfOp); } NestedPattern If(FilterFunctionType filter, ArrayRef nested) { return NestedPattern(nested, [filter](Operation &op) { return isAffineIfOp(op) && filter(op); }); } NestedPattern For(NestedPattern child) { return NestedPattern(child, isAffineForOp); } NestedPattern For(FilterFunctionType filter, NestedPattern child) { return NestedPattern( child, [=](Operation &op) { return isAffineForOp(op) && filter(op); }); } NestedPattern For(ArrayRef nested) { return NestedPattern(nested, isAffineForOp); } NestedPattern For(FilterFunctionType filter, ArrayRef nested) { return NestedPattern( nested, [=](Operation &op) { return isAffineForOp(op) && filter(op); }); } bool isLoadOrStore(Operation &op) { return isa(op); } } // end namespace matcher } // end namespace mlir