• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- OpReducer.h - MLIR Reduce Operation Reducer ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines the OpReducer class. It defines a variant generator method
// with the purpose of producing different variants by eliminating a
// parameterizable type of operations from the parent module.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_REDUCER_PASSES_OPREDUCER_H
16 #define MLIR_REDUCER_PASSES_OPREDUCER_H
17 
#include "mlir/IR/Region.h"
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/ReductionTreeUtils.h"
#include "mlir/Reducer/Tester.h"

#include <algorithm>
#include <memory>
#include <vector>
22 
23 namespace mlir {
24 
25 class OpReducerImpl {
26 public:
27   OpReducerImpl(
28       llvm::function_ref<std::vector<Operation *>(ModuleOp)> getSpecificOps);
29 
30   /// Return the name of this reducer class.
31   StringRef getName();
32 
33   /// Return the initial transformSpace containing the transformable indices.
34   std::vector<bool> initTransformSpace(ModuleOp module);
35 
36   /// Generate variants by removing OpType operations from the module in the
37   /// parent and link the variants as childs in the Reduction Tree Pass.
38   void generateVariants(ReductionNode *parent, const Tester &test,
39                         int numVariants);
40 
41   /// Generate variants by removing OpType operations from the module in the
42   /// parent and link the variants as childs in the Reduction Tree Pass. The
43   /// transform argument defines the function used to remove the OpTpye
44   /// operations in range of indexed OpType operations.
45   void generateVariants(ReductionNode *parent, const Tester &test,
46                         int numVariants,
47                         llvm::function_ref<void(ModuleOp, int, int)> transform);
48 
49 private:
50   llvm::function_ref<std::vector<Operation *>(ModuleOp)> getSpecificOps;
51 };
52 
53 /// The OpReducer class defines a variant generator method that produces
54 /// multiple variants by eliminating different OpType operations from the
55 /// parent module.
56 template <typename OpType>
57 class OpReducer {
58 public:
OpReducer()59   OpReducer() : impl(new OpReducerImpl(getSpecificOps)) {}
60 
61   /// Returns the vector of pointer to the OpType operations in the module.
getSpecificOps(ModuleOp module)62   static std::vector<Operation *> getSpecificOps(ModuleOp module) {
63     std::vector<Operation *> ops;
64     for (auto op : module.getOps<OpType>()) {
65       ops.push_back(op);
66     }
67     return ops;
68   }
69 
70   /// Deletes the OpType operations in the module in the specified index.
deleteOps(ModuleOp module,int start,int end)71   static void deleteOps(ModuleOp module, int start, int end) {
72     std::vector<Operation *> opsToRemove;
73 
74     for (auto op : enumerate(getSpecificOps(module))) {
75       int index = op.index();
76       if (index >= start && index < end)
77         opsToRemove.push_back(op.value());
78     }
79 
80     for (Operation *o : opsToRemove) {
81       o->dropAllUses();
82       o->erase();
83     }
84   }
85 
86   /// Return the name of this reducer class.
getName()87   StringRef getName() { return impl->getName(); }
88 
89   /// Return the initial transformSpace containing the transformable indices.
initTransformSpace(ModuleOp module)90   std::vector<bool> initTransformSpace(ModuleOp module) {
91     return impl->initTransformSpace(module);
92   }
93 
94   /// Generate variants by removing OpType operations from the module in the
95   /// parent and link the variants as childs in the Reduction Tree Pass.
generateVariants(ReductionNode * parent,const Tester & test,int numVariants)96   void generateVariants(ReductionNode *parent, const Tester &test,
97                         int numVariants) {
98     impl->generateVariants(parent, test, numVariants, deleteOps);
99   }
100 
101 private:
102   std::unique_ptr<OpReducerImpl> impl;
103 };
104 
105 } // end namespace mlir
106 
107 #endif
108