1 //===- ReductionTreePass.h - Reduction Tree Pass Implementation -*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the Reduction Tree Pass class. It provides a framework for 10 // the implementation of different reduction passes in the MLIR Reduce tool. It 11 // allows for custom specification of the variant generation behavior. It 12 // implements methods that define the different possible traversals of the 13 // reduction tree. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #ifndef MLIR_REDUCER_REDUCTIONTREEPASS_H 18 #define MLIR_REDUCER_REDUCTIONTREEPASS_H 19 20 #include <vector> 21 22 #include "PassDetail.h" 23 #include "ReductionNode.h" 24 #include "mlir/Reducer/Passes/OpReducer.h" 25 #include "mlir/Reducer/ReductionTreeUtils.h" 26 #include "mlir/Reducer/Tester.h" 27 28 #define DEBUG_TYPE "mlir-reduce" 29 30 namespace mlir { 31 32 // Defines the traversal method options to be used in the reduction tree 33 /// traversal. 34 enum TraversalMode { SinglePath, Backtrack, MultiPath }; 35 36 /// This class defines the Reduction Tree Pass. It provides a framework to 37 /// to implement a reduction pass using a tree structure to keep track of the 38 /// generated reduced variants. 39 template <typename Reducer, TraversalMode mode> 40 class ReductionTreePass 41 : public ReductionTreeBase<ReductionTreePass<Reducer, mode>> { 42 public: ReductionTreePass(const ReductionTreePass & pass)43 ReductionTreePass(const ReductionTreePass &pass) 44 : ReductionTreeBase<ReductionTreePass<Reducer, mode>>(pass), 45 root(new ReductionNode(pass.root->getModule().clone(), nullptr)), 46 test(pass.test) {} 47 ReductionTreePass(const Tester & test)48 ReductionTreePass(const Tester &test) : test(test) {} 49 50 /// Runs the pass instance in the pass pipeline. runOnOperation()51 void runOnOperation() override { 52 ModuleOp module = this->getOperation(); 53 Reducer reducer; 54 std::vector<bool> transformSpace = reducer.initTransformSpace(module); 55 ReductionNode *reduced; 56 57 this->root = 58 std::make_unique<ReductionNode>(module, nullptr, transformSpace); 59 60 root->measureAndTest(test); 61 62 LLVM_DEBUG(llvm::dbgs() << "\nReduction Tree Pass: " << reducer.getName();); 63 switch (mode) { 64 case SinglePath: 65 LLVM_DEBUG(llvm::dbgs() << " (Single Path)\n";); 66 reduced = singlePathTraversal(); 67 break; 68 default: 69 llvm::report_fatal_error("Traversal method not currently supported."); 70 } 71 72 ReductionTreeUtils::updateGoldenModule(module, 73 reduced->getModule().clone()); 74 } 75 76 private: 77 // Points to the root node in this reduction tree. 78 std::unique_ptr<ReductionNode> root; 79 80 // This object defines the variant generation at each level of the reduction 81 // tree. 82 Reducer reducer; 83 84 // This is used to test the interesting behavior of the reduction nodes in the 85 // tree. 86 const Tester &test; 87 88 /// Traverse the most reduced path in the reduction tree by generating the 89 /// variants at each level using the Reducer parameter's generateVariants 90 /// function. Stops when no new successful variants can be created at the 91 /// current level. singlePathTraversal()92 ReductionNode *singlePathTraversal() { 93 ReductionNode *currNode = root.get(); 94 ReductionNode *smallestNode = currNode; 95 int tSpaceSize = currNode->transformSpaceSize(); 96 std::vector<int> path; 97 98 ReductionTreeUtils::updateSmallestNode(currNode, smallestNode, path); 99 100 LLVM_DEBUG(llvm::dbgs() << "\nGenerating 1 variant: applying the "); 101 LLVM_DEBUG(llvm::dbgs() << "transformation to the entire module\n"); 102 103 reducer.generateVariants(currNode, test, 1); 104 LLVM_DEBUG(llvm::dbgs() << "Testing\n"); 105 currNode->organizeVariants(test); 106 107 if (!currNode->variantsEmpty()) 108 return currNode->getVariant(0); 109 110 while (tSpaceSize != 1) { 111 ReductionTreeUtils::updateSmallestNode(currNode, smallestNode, path); 112 113 LLVM_DEBUG(llvm::dbgs() << "\nGenerating 2 variants: applying the "); 114 LLVM_DEBUG(llvm::dbgs() << "transformation to two different sections "); 115 LLVM_DEBUG(llvm::dbgs() << "of transformable indices\n"); 116 117 reducer.generateVariants(currNode, test, 2); 118 LLVM_DEBUG(llvm::dbgs() << "Testing\n"); 119 currNode->organizeVariants(test); 120 121 if (currNode->variantsEmpty()) 122 break; 123 124 currNode = currNode->getVariant(0); 125 tSpaceSize = currNode->transformSpaceSize(); 126 path.push_back(0); 127 } 128 129 if (tSpaceSize == 1) { 130 ReductionTreeUtils::updateSmallestNode(currNode, smallestNode, path); 131 132 LLVM_DEBUG(llvm::dbgs() << "\nGenerating 1 variants: applying the "); 133 LLVM_DEBUG(llvm::dbgs() << "transformation to the only transformable"); 134 LLVM_DEBUG(llvm::dbgs() << "index\n"); 135 136 reducer.generateVariants(currNode, test, 1); 137 LLVM_DEBUG(llvm::dbgs() << "Testing\n"); 138 currNode->organizeVariants(test); 139 140 if (!currNode->variantsEmpty()) { 141 currNode = currNode->getVariant(0); 142 path.push_back(0); 143 144 ReductionTreeUtils::updateSmallestNode(currNode, smallestNode, path); 145 } 146 } 147 148 return currNode; 149 } 150 }; 151 152 } // end namespace mlir 153 154 #endif 155