1 //===- ReductionNode.cpp - Reduction Node Implementation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines the reduction nodes which are used to keep track of the
10 // metadata for a specific generated variant within a reduction pass and are the
11 // building blocks of the reduction tree structure. A reduction tree is used to
12 // keep track of the different generated variants throughout a reduction pass in
13 // the MLIR Reduce tool.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "mlir/Reducer/ReductionNode.h"
18
19 using namespace mlir;
20
21 /// Sets up the metadata and links the node to its parent.
ReductionNode(ModuleOp module,ReductionNode * parent)22 ReductionNode::ReductionNode(ModuleOp module, ReductionNode *parent)
23 : module(module), evaluated(false) {
24
25 if (parent != nullptr)
26 parent->linkVariant(this);
27 }
28
ReductionNode(ModuleOp module,ReductionNode * parent,std::vector<bool> transformSpace)29 ReductionNode::ReductionNode(ModuleOp module, ReductionNode *parent,
30 std::vector<bool> transformSpace)
31 : module(module), evaluated(false), transformSpace(transformSpace) {
32
33 if (parent != nullptr)
34 parent->linkVariant(this);
35 }
36
37 /// Calculates and updates the size and interesting values of the module.
measureAndTest(const Tester & test)38 void ReductionNode::measureAndTest(const Tester &test) {
39 SmallString<128> filepath;
40 int fd;
41
42 // Print module to temporary file.
43 std::error_code ec =
44 llvm::sys::fs::createTemporaryFile("mlir-reduce", "mlir", fd, filepath);
45
46 if (ec)
47 llvm::report_fatal_error("Error making unique filename: " + ec.message());
48
49 llvm::ToolOutputFile out(filepath, fd);
50 module.print(out.os());
51 out.os().close();
52
53 if (out.os().has_error())
54 llvm::report_fatal_error("Error emitting bitcode to file '" + filepath);
55
56 size = out.os().tell();
57 interesting = test.isInteresting(filepath);
58 evaluated = true;
59 }
60
61 /// Returns true if the size and interestingness have been calculated.
isEvaluated() const62 bool ReductionNode::isEvaluated() const { return evaluated; }
63
64 /// Returns the size in bytes of the module.
getSize() const65 int ReductionNode::getSize() const { return size; }
66
67 /// Returns true if the module exhibits the interesting behavior.
isInteresting() const68 bool ReductionNode::isInteresting() const { return interesting; }
69
70 /// Returns the pointers to the child variants.
getVariant(unsigned long index) const71 ReductionNode *ReductionNode::getVariant(unsigned long index) const {
72 if (index < variants.size())
73 return variants[index].get();
74
75 return nullptr;
76 }
77
78 /// Returns the number of child variants.
variantsSize() const79 int ReductionNode::variantsSize() const { return variants.size(); }
80
81 /// Returns true if the child variants vector is empty.
variantsEmpty() const82 bool ReductionNode::variantsEmpty() const { return variants.empty(); }
83
84 /// Link a child variant node.
linkVariant(ReductionNode * newVariant)85 void ReductionNode::linkVariant(ReductionNode *newVariant) {
86 std::unique_ptr<ReductionNode> ptrVariant(newVariant);
87 variants.push_back(std::move(ptrVariant));
88 }
89
90 /// Sort the child variants and remove the uninteresting ones.
organizeVariants(const Tester & test)91 void ReductionNode::organizeVariants(const Tester &test) {
92 // Ensure all variants are evaluated.
93 for (auto &var : variants)
94 if (!var->isEvaluated())
95 var->measureAndTest(test);
96
97 // Sort variants by interestingness and size.
98 llvm::array_pod_sort(
99 variants.begin(), variants.end(), [](const auto *lhs, const auto *rhs) {
100 if (lhs->get()->isInteresting() && !rhs->get()->isInteresting())
101 return 0;
102
103 if (!lhs->get()->isInteresting() && rhs->get()->isInteresting())
104 return 1;
105
106 return (lhs->get()->getSize(), rhs->get()->getSize());
107 });
108
109 int interestingCount = 0;
110 for (auto &var : variants) {
111 if (var->isInteresting()) {
112 ++interestingCount;
113 } else {
114 break;
115 }
116 }
117
118 // Remove uninteresting variants.
119 variants.resize(interestingCount);
120 }
121
122 /// Returns the number of non transformed indices.
transformSpaceSize()123 int ReductionNode::transformSpaceSize() {
124 return std::count(transformSpace.begin(), transformSpace.end(), false);
125 }
126
127 /// Returns a vector of the transformable indices in the Module.
getTransformSpace()128 const std::vector<bool> ReductionNode::getTransformSpace() {
129 return transformSpace;
130 }
131