1 //===- ReductionTreeUtils.cpp - Reduction Tree Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Reduction Tree Utilities. It defines pass independent
10 // methods that help in a reduction pass of the MLIR Reduce tool.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Reducer/ReductionTreeUtils.h"
15
16 #define DEBUG_TYPE "mlir-reduce"
17
18 using namespace mlir;
19
20 /// Update the golden module's content with that of the reduced module.
updateGoldenModule(ModuleOp & golden,ModuleOp reduced)21 void ReductionTreeUtils::updateGoldenModule(ModuleOp &golden,
22 ModuleOp reduced) {
23 golden.getBody()->clear();
24
25 golden.getBody()->getOperations().splice(golden.getBody()->begin(),
26 reduced.getBody()->getOperations());
27 }
28
29 /// Update the smallest node traversed so far in the reduction tree and
30 /// print the debugging information for the currNode being traversed.
updateSmallestNode(ReductionNode * currNode,ReductionNode * & smallestNode,std::vector<int> path)31 void ReductionTreeUtils::updateSmallestNode(ReductionNode *currNode,
32 ReductionNode *&smallestNode,
33 std::vector<int> path) {
34 LLVM_DEBUG(llvm::dbgs() << "\nTree Path: root");
35 #ifndef NDEBUG
36 for (int nodeIndex : path)
37 LLVM_DEBUG(llvm::dbgs() << " -> " << nodeIndex);
38 #endif
39
40 LLVM_DEBUG(llvm::dbgs() << "\nSize (chars): " << currNode->getSize());
41 if (currNode->getSize() < smallestNode->getSize()) {
42 LLVM_DEBUG(llvm::dbgs() << " - new smallest node!");
43 smallestNode = currNode;
44 }
45 }
46
47 /// Create a transform space index vector based on the specified number of
48 /// indices.
createTransformSpace(ModuleOp module,int numIndices)49 std::vector<bool> ReductionTreeUtils::createTransformSpace(ModuleOp module,
50 int numIndices) {
51 std::vector<bool> transformSpace;
52 for (int i = 0; i < numIndices; ++i)
53 transformSpace.push_back(false);
54
55 return transformSpace;
56 }
57
/// Translate a section of non-transformed indices into ranges of positions in
/// the transform space.
///
/// `start` and `end` count only the non-transformed (`false`) entries of
/// `tSpace`: the section covers the `start`-th through `(end - 1)`-th such
/// entries. Each returned tuple is a half-open `[rangeStart, rangeEnd)` range
/// of absolute positions in `tSpace` forming a maximal run of consecutive
/// non-transformed entries inside that section.
///
/// \param tSpace Transform space (`true` = already transformed).
/// \param start  First non-transformed index of the section (inclusive).
/// \param end    Last non-transformed index of the section (exclusive). It may
///               exceed the number of non-transformed entries, in which case
///               the section runs to the end of `tSpace`.
/// \returns The ranges, in increasing position order.
static std::vector<std::tuple<int, int>>
getRanges(const std::vector<bool> &tSpace, int start, int end) {
  std::vector<std::tuple<int, int>> ranges;
  int rangeStart = 0;
  bool inside = false;
  // Number of non-transformed entries strictly before the current position.
  int transformableCount = 0;

  for (int index = 0, e = static_cast<int>(tSpace.size()); index != e;
       ++index) {
    bool value = tSpace[index];

    if (start <= transformableCount && transformableCount < end) {
      // Entering a run of non-transformed entries.
      if (!value && !inside) {
        inside = true;
        rangeStart = index;
      }
      // A transformed entry terminates the current run.
      if (value && inside) {
        ranges.push_back(std::make_tuple(rangeStart, index));
        inside = false;
      }
    }

    if (!value)
      ++transformableCount;

    // The section's last non-transformed entry was just consumed.
    if (transformableCount == end && inside) {
      ranges.push_back(std::make_tuple(rangeStart, index + 1));
      inside = false;
      break;
    }
  }

  // A run that reaches the end of the transform space (possible when `end`
  // exceeds the number of non-transformed entries) must still be reported;
  // previously it was silently dropped.
  if (inside)
    ranges.push_back(
        std::make_tuple(rangeStart, static_cast<int>(tSpace.size())));

  return ranges;
}
96
97 /// Create the specified number of variants by applying the transform method
98 /// to different ranges of indices in the parent module. The isDeletion boolean
99 /// specifies if the transformation is the deletion of indices.
createVariants(ReductionNode * parent,const Tester & test,int numVariants,llvm::function_ref<void (ModuleOp,int,int)> transform,bool isDeletion)100 void ReductionTreeUtils::createVariants(
101 ReductionNode *parent, const Tester &test, int numVariants,
102 llvm::function_ref<void(ModuleOp, int, int)> transform, bool isDeletion) {
103 std::vector<bool> newTSpace;
104 ModuleOp module = parent->getModule();
105
106 std::vector<bool> parentTSpace = parent->getTransformSpace();
107 int indexCount = parent->transformSpaceSize();
108 std::vector<std::tuple<int, int>> ranges;
109
110 // No new variants can be created.
111 if (indexCount == 0)
112 return;
113
114 // Create a single variant by transforming the unique index.
115 if (indexCount == 1) {
116 ModuleOp variantModule = module.clone();
117 if (isDeletion) {
118 transform(variantModule, 0, 1);
119 } else {
120 ranges = getRanges(parentTSpace, 0, parentTSpace.size());
121 transform(variantModule, std::get<0>(ranges[0]), std::get<1>(ranges[0]));
122 }
123
124 new ReductionNode(variantModule, parent, newTSpace);
125
126 return;
127 }
128
129 // Create the specified number of variants.
130 for (int i = 0; i < numVariants; ++i) {
131 ModuleOp variantModule = module.clone();
132 newTSpace = parent->getTransformSpace();
133 int sectionSize = indexCount / numVariants;
134 int sectionStart = sectionSize * i;
135 int sectionEnd = sectionSize * (i + 1);
136
137 if (i == numVariants - 1)
138 sectionEnd = indexCount;
139
140 if (isDeletion)
141 transform(variantModule, sectionStart, sectionEnd);
142
143 ranges = getRanges(parentTSpace, sectionStart, sectionEnd);
144
145 for (auto range : ranges) {
146 int rangeStart = std::get<0>(range);
147 int rangeEnd = std::get<1>(range);
148
149 for (int x = rangeStart; x < rangeEnd; ++x)
150 newTSpace[x] = true;
151
152 if (!isDeletion)
153 transform(variantModule, rangeStart, rangeEnd);
154 }
155
156 // Create Reduction Node in the Reduction tree
157 new ReductionNode(variantModule, parent, newTSpace);
158 }
159 }
160