• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- ReductionTreeUtils.cpp - Reduction Tree Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Reduction Tree Utilities: pass-independent helper
10 // methods used by the reduction passes of the MLIR Reduce tool.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Reducer/ReductionTreeUtils.h"
15 
16 #define DEBUG_TYPE "mlir-reduce"
17 
18 using namespace mlir;
19 
20 /// Update the golden module's content with that of the reduced module.
updateGoldenModule(ModuleOp & golden,ModuleOp reduced)21 void ReductionTreeUtils::updateGoldenModule(ModuleOp &golden,
22                                             ModuleOp reduced) {
23   golden.getBody()->clear();
24 
25   golden.getBody()->getOperations().splice(golden.getBody()->begin(),
26                                            reduced.getBody()->getOperations());
27 }
28 
29 /// Update the smallest node traversed so far in the reduction tree and
30 /// print the debugging information for the currNode being traversed.
updateSmallestNode(ReductionNode * currNode,ReductionNode * & smallestNode,std::vector<int> path)31 void ReductionTreeUtils::updateSmallestNode(ReductionNode *currNode,
32                                             ReductionNode *&smallestNode,
33                                             std::vector<int> path) {
34   LLVM_DEBUG(llvm::dbgs() << "\nTree Path: root");
35   #ifndef NDEBUG
36   for (int nodeIndex : path)
37     LLVM_DEBUG(llvm::dbgs() << " -> " << nodeIndex);
38   #endif
39 
40   LLVM_DEBUG(llvm::dbgs() << "\nSize (chars): " << currNode->getSize());
41   if (currNode->getSize() < smallestNode->getSize()) {
42     LLVM_DEBUG(llvm::dbgs() << " - new smallest node!");
43     smallestNode = currNode;
44   }
45 }
46 
47 /// Create a transform space index vector based on the specified number of
48 /// indices.
createTransformSpace(ModuleOp module,int numIndices)49 std::vector<bool> ReductionTreeUtils::createTransformSpace(ModuleOp module,
50                                                            int numIndices) {
51   std::vector<bool> transformSpace;
52   for (int i = 0; i < numIndices; ++i)
53     transformSpace.push_back(false);
54 
55   return transformSpace;
56 }
57 
/// Translate a section of the transform space into index ranges.
///
/// The section is given as [\p start, \p end), counted over the
/// *untransformed* (false) entries of \p tSpace only. The result is the list
/// of maximal half-open ranges [first, last) of absolute positions in
/// \p tSpace whose entries are untransformed and fall inside that section.
///
/// If the section extends past the last untransformed entry, any range still
/// open when the input runs out is closed at tSpace.size(). Previously such a
/// trailing range was silently dropped.
static std::vector<std::tuple<int, int>>
getRanges(const std::vector<bool> &tSpace, int start, int end) {
  std::vector<std::tuple<int, int>> ranges;
  const int size = static_cast<int>(tSpace.size());
  int rangeStart = 0;
  bool inside = false;
  // Number of untransformed entries seen before the current position.
  int transformableCount = 0;

  for (int index = 0; index < size; ++index) {
    const bool transformed = tSpace[index];

    if (start <= transformableCount && transformableCount < end) {
      if (!transformed && !inside) {
        // First untransformed entry of a new range inside the section.
        inside = true;
        rangeStart = index;
      } else if (transformed && inside) {
        // An already-transformed entry terminates the current range.
        ranges.emplace_back(rangeStart, index);
        inside = false;
      }
    }

    if (!transformed)
      ++transformableCount;

    // The section's last untransformed entry was just consumed.
    if (transformableCount == end && inside) {
      ranges.emplace_back(rangeStart, index + 1);
      return ranges;
    }
  }

  // Input exhausted while a range was still open (end exceeds the number of
  // untransformed entries): every entry from rangeStart onward is
  // untransformed, so close the range at the end of the transform space.
  if (inside)
    ranges.emplace_back(rangeStart, size);

  return ranges;
}
96 
97 /// Create the specified number of variants by applying the transform method
98 /// to different ranges of indices in the parent module. The isDeletion boolean
99 /// specifies if the transformation is the deletion of indices.
createVariants(ReductionNode * parent,const Tester & test,int numVariants,llvm::function_ref<void (ModuleOp,int,int)> transform,bool isDeletion)100 void ReductionTreeUtils::createVariants(
101     ReductionNode *parent, const Tester &test, int numVariants,
102     llvm::function_ref<void(ModuleOp, int, int)> transform, bool isDeletion) {
103   std::vector<bool> newTSpace;
104   ModuleOp module = parent->getModule();
105 
106   std::vector<bool> parentTSpace = parent->getTransformSpace();
107   int indexCount = parent->transformSpaceSize();
108   std::vector<std::tuple<int, int>> ranges;
109 
110   // No new variants can be created.
111   if (indexCount == 0)
112     return;
113 
114   // Create a single variant by transforming the unique index.
115   if (indexCount == 1) {
116     ModuleOp variantModule = module.clone();
117     if (isDeletion) {
118       transform(variantModule, 0, 1);
119     } else {
120       ranges = getRanges(parentTSpace, 0, parentTSpace.size());
121       transform(variantModule, std::get<0>(ranges[0]), std::get<1>(ranges[0]));
122     }
123 
124     new ReductionNode(variantModule, parent, newTSpace);
125 
126     return;
127   }
128 
129   // Create the specified number of variants.
130   for (int i = 0; i < numVariants; ++i) {
131     ModuleOp variantModule = module.clone();
132     newTSpace = parent->getTransformSpace();
133     int sectionSize = indexCount / numVariants;
134     int sectionStart = sectionSize * i;
135     int sectionEnd = sectionSize * (i + 1);
136 
137     if (i == numVariants - 1)
138       sectionEnd = indexCount;
139 
140     if (isDeletion)
141       transform(variantModule, sectionStart, sectionEnd);
142 
143     ranges = getRanges(parentTSpace, sectionStart, sectionEnd);
144 
145     for (auto range : ranges) {
146       int rangeStart = std::get<0>(range);
147       int rangeEnd = std::get<1>(range);
148 
149       for (int x = rangeStart; x < rangeEnd; ++x)
150         newTSpace[x] = true;
151 
152       if (!isDeletion)
153         transform(variantModule, rangeStart, rangeEnd);
154     }
155 
156     // Create Reduction Node in the Reduction tree
157     new ReductionNode(variantModule, parent, newTSpace);
158   }
159 }
160