1 //===- TestSlicing.cpp - Testing slice functionality ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a simple testing pass for slicing.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Analysis/SliceAnalysis.h"
14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/IR/BlockAndValueMapping.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Support/LLVM.h"
21
22 using namespace mlir;
23
24 /// Create a function with the same signature as the parent function of `op`
25 /// with name being the function name and a `suffix`.
createBackwardSliceFunction(Operation * op,StringRef suffix)26 static LogicalResult createBackwardSliceFunction(Operation *op,
27 StringRef suffix) {
28 FuncOp parentFuncOp = op->getParentOfType<FuncOp>();
29 OpBuilder builder(parentFuncOp);
30 Location loc = op->getLoc();
31 std::string clonedFuncOpName = parentFuncOp.getName().str() + suffix.str();
32 FuncOp clonedFuncOp =
33 builder.create<FuncOp>(loc, clonedFuncOpName, parentFuncOp.getType());
34 BlockAndValueMapping mapper;
35 builder.setInsertionPointToEnd(clonedFuncOp.addEntryBlock());
36 for (auto arg : enumerate(parentFuncOp.getArguments()))
37 mapper.map(arg.value(), clonedFuncOp.getArgument(arg.index()));
38 llvm::SetVector<Operation *> slice;
39 getBackwardSlice(op, &slice);
40 for (Operation *slicedOp : slice)
41 builder.clone(*slicedOp, mapper);
42 builder.create<ReturnOp>(loc);
43 return success();
44 }
45
46 namespace {
47 /// Pass to test slice generated from slice analysis.
48 struct SliceAnalysisTestPass
49 : public PassWrapper<SliceAnalysisTestPass, OperationPass<ModuleOp>> {
50 void runOnOperation() override;
51 SliceAnalysisTestPass() = default;
SliceAnalysisTestPass__anona8d01ee20111::SliceAnalysisTestPass52 SliceAnalysisTestPass(const SliceAnalysisTestPass &) {}
53 };
54 } // namespace
55
runOnOperation()56 void SliceAnalysisTestPass::runOnOperation() {
57 ModuleOp module = getOperation();
58 auto funcOps = module.getOps<FuncOp>();
59 unsigned opNum = 0;
60 for (auto funcOp : funcOps) {
61 // TODO: For now this is just looking for Linalg ops. It can be generalized
62 // to look for other ops using flags.
63 funcOp.walk([&](Operation *op) {
64 if (!isa<linalg::LinalgOp>(op))
65 return WalkResult::advance();
66 std::string append =
67 std::string("__backward_slice__") + std::to_string(opNum);
68 createBackwardSliceFunction(op, append);
69 opNum++;
70 return WalkResult::advance();
71 });
72 }
73 }
74
75 namespace mlir {
registerSliceAnalysisTestPass()76 void registerSliceAnalysisTestPass() {
77 PassRegistration<SliceAnalysisTestPass> pass(
78 "slice-analysis-test", "Test Slice analysis functionality.");
79 }
80 } // namespace mlir
81