1 //===- OpReducer.h - MLIR Reduce Operation Reducer ------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the OpReducer class. It defines a variant generator method 10 // with the purpose of producing different variants by eliminating a 11 // parameterizable type of operations from the parent module. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #ifndef MLIR_REDUCER_PASSES_OPREDUCER_H 16 #define MLIR_REDUCER_PASSES_OPREDUCER_H 17 18 #include "mlir/IR/Region.h" 19 #include "mlir/Reducer/ReductionNode.h" 20 #include "mlir/Reducer/ReductionTreeUtils.h" 21 #include "mlir/Reducer/Tester.h" 22 23 namespace mlir { 24 25 class OpReducerImpl { 26 public: 27 OpReducerImpl( 28 llvm::function_ref<std::vector<Operation *>(ModuleOp)> getSpecificOps); 29 30 /// Return the name of this reducer class. 31 StringRef getName(); 32 33 /// Return the initial transformSpace containing the transformable indices. 34 std::vector<bool> initTransformSpace(ModuleOp module); 35 36 /// Generate variants by removing OpType operations from the module in the 37 /// parent and link the variants as childs in the Reduction Tree Pass. 38 void generateVariants(ReductionNode *parent, const Tester &test, 39 int numVariants); 40 41 /// Generate variants by removing OpType operations from the module in the 42 /// parent and link the variants as childs in the Reduction Tree Pass. The 43 /// transform argument defines the function used to remove the OpTpye 44 /// operations in range of indexed OpType operations. 45 void generateVariants(ReductionNode *parent, const Tester &test, 46 int numVariants, 47 llvm::function_ref<void(ModuleOp, int, int)> transform); 48 49 private: 50 llvm::function_ref<std::vector<Operation *>(ModuleOp)> getSpecificOps; 51 }; 52 53 /// The OpReducer class defines a variant generator method that produces 54 /// multiple variants by eliminating different OpType operations from the 55 /// parent module. 56 template <typename OpType> 57 class OpReducer { 58 public: OpReducer()59 OpReducer() : impl(new OpReducerImpl(getSpecificOps)) {} 60 61 /// Returns the vector of pointer to the OpType operations in the module. getSpecificOps(ModuleOp module)62 static std::vector<Operation *> getSpecificOps(ModuleOp module) { 63 std::vector<Operation *> ops; 64 for (auto op : module.getOps<OpType>()) { 65 ops.push_back(op); 66 } 67 return ops; 68 } 69 70 /// Deletes the OpType operations in the module in the specified index. deleteOps(ModuleOp module,int start,int end)71 static void deleteOps(ModuleOp module, int start, int end) { 72 std::vector<Operation *> opsToRemove; 73 74 for (auto op : enumerate(getSpecificOps(module))) { 75 int index = op.index(); 76 if (index >= start && index < end) 77 opsToRemove.push_back(op.value()); 78 } 79 80 for (Operation *o : opsToRemove) { 81 o->dropAllUses(); 82 o->erase(); 83 } 84 } 85 86 /// Return the name of this reducer class. getName()87 StringRef getName() { return impl->getName(); } 88 89 /// Return the initial transformSpace containing the transformable indices. initTransformSpace(ModuleOp module)90 std::vector<bool> initTransformSpace(ModuleOp module) { 91 return impl->initTransformSpace(module); 92 } 93 94 /// Generate variants by removing OpType operations from the module in the 95 /// parent and link the variants as childs in the Reduction Tree Pass. generateVariants(ReductionNode * parent,const Tester & test,int numVariants)96 void generateVariants(ReductionNode *parent, const Tester &test, 97 int numVariants) { 98 impl->generateVariants(parent, test, numVariants, deleteOps); 99 } 100 101 private: 102 std::unique_ptr<OpReducerImpl> impl; 103 }; 104 105 } // end namespace mlir 106 107 #endif 108