1 //===- OpReducer.cpp - Operation Reducer ------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the OpReducer class. It defines a variant generator method
10 // with the purpose of producing different variants by eliminating a
11 // parameterizable type of operations from the parent module.
12 //
13 //===----------------------------------------------------------------------===//
14 #include "mlir/Reducer/Passes/OpReducer.h"
15
16 using namespace mlir;
17
OpReducerImpl(llvm::function_ref<std::vector<Operation * > (ModuleOp)> getSpecificOps)18 OpReducerImpl::OpReducerImpl(
19 llvm::function_ref<std::vector<Operation *>(ModuleOp)> getSpecificOps)
20 : getSpecificOps(getSpecificOps) {}
21
22 /// Return the name of this reducer class.
getName()23 StringRef OpReducerImpl::getName() {
24 return StringRef("High Level Operation Reduction");
25 }
26
27 /// Return the initial transformSpace containing the transformable indices.
initTransformSpace(ModuleOp module)28 std::vector<bool> OpReducerImpl::initTransformSpace(ModuleOp module) {
29 auto ops = getSpecificOps(module);
30 int numOps = std::distance(ops.begin(), ops.end());
31 return ReductionTreeUtils::createTransformSpace(module, numOps);
32 }
33
34 /// Generate variants by removing opType operations from the module in the
35 /// parent and link the variants as childs in the Reduction Tree Pass.
generateVariants(ReductionNode * parent,const Tester & test,int numVariants,llvm::function_ref<void (ModuleOp,int,int)> transform)36 void OpReducerImpl::generateVariants(
37 ReductionNode *parent, const Tester &test, int numVariants,
38 llvm::function_ref<void(ModuleOp, int, int)> transform) {
39 ReductionTreeUtils::createVariants(parent, test, numVariants, transform,
40 true);
41 }
42