• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- ReductionNode.cpp - Reduction Node Implementation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines the reduction nodes which are used to keep track of the
// metadata for a specific generated variant within a reduction pass and are the
// building blocks of the reduction tree structure. A reduction tree is used to
// keep track of the different generated variants throughout a reduction pass in
// the MLIR Reduce tool.
14 //
15 //===----------------------------------------------------------------------===//
16 
#include "mlir/Reducer/ReductionNode.h"

#include <algorithm>
18 
19 using namespace mlir;
20 
21 /// Sets up the metadata and links the node to its parent.
ReductionNode(ModuleOp module,ReductionNode * parent)22 ReductionNode::ReductionNode(ModuleOp module, ReductionNode *parent)
23     : module(module), evaluated(false) {
24 
25   if (parent != nullptr)
26     parent->linkVariant(this);
27 }
28 
ReductionNode(ModuleOp module,ReductionNode * parent,std::vector<bool> transformSpace)29 ReductionNode::ReductionNode(ModuleOp module, ReductionNode *parent,
30                              std::vector<bool> transformSpace)
31     : module(module), evaluated(false), transformSpace(transformSpace) {
32 
33   if (parent != nullptr)
34     parent->linkVariant(this);
35 }
36 
37 /// Calculates and updates the size and interesting values of the module.
measureAndTest(const Tester & test)38 void ReductionNode::measureAndTest(const Tester &test) {
39   SmallString<128> filepath;
40   int fd;
41 
42   // Print module to temporary file.
43   std::error_code ec =
44       llvm::sys::fs::createTemporaryFile("mlir-reduce", "mlir", fd, filepath);
45 
46   if (ec)
47     llvm::report_fatal_error("Error making unique filename: " + ec.message());
48 
49   llvm::ToolOutputFile out(filepath, fd);
50   module.print(out.os());
51   out.os().close();
52 
53   if (out.os().has_error())
54     llvm::report_fatal_error("Error emitting bitcode to file '" + filepath);
55 
56   size = out.os().tell();
57   interesting = test.isInteresting(filepath);
58   evaluated = true;
59 }
60 
61 /// Returns true if the size and interestingness have been calculated.
isEvaluated() const62 bool ReductionNode::isEvaluated() const { return evaluated; }
63 
64 /// Returns the size in bytes of the module.
getSize() const65 int ReductionNode::getSize() const { return size; }
66 
67 /// Returns true if the module exhibits the interesting behavior.
isInteresting() const68 bool ReductionNode::isInteresting() const { return interesting; }
69 
70 /// Returns the pointers to the child variants.
getVariant(unsigned long index) const71 ReductionNode *ReductionNode::getVariant(unsigned long index) const {
72   if (index < variants.size())
73     return variants[index].get();
74 
75   return nullptr;
76 }
77 
78 /// Returns the number of child variants.
variantsSize() const79 int ReductionNode::variantsSize() const { return variants.size(); }
80 
81 /// Returns true if the child variants vector is empty.
variantsEmpty() const82 bool ReductionNode::variantsEmpty() const { return variants.empty(); }
83 
84 /// Link a child variant node.
linkVariant(ReductionNode * newVariant)85 void ReductionNode::linkVariant(ReductionNode *newVariant) {
86   std::unique_ptr<ReductionNode> ptrVariant(newVariant);
87   variants.push_back(std::move(ptrVariant));
88 }
89 
90 /// Sort the child variants and remove the uninteresting ones.
organizeVariants(const Tester & test)91 void ReductionNode::organizeVariants(const Tester &test) {
92   // Ensure all variants are evaluated.
93   for (auto &var : variants)
94     if (!var->isEvaluated())
95       var->measureAndTest(test);
96 
97   // Sort variants by interestingness and size.
98   llvm::array_pod_sort(
99       variants.begin(), variants.end(), [](const auto *lhs, const auto *rhs) {
100         if (lhs->get()->isInteresting() && !rhs->get()->isInteresting())
101           return 0;
102 
103         if (!lhs->get()->isInteresting() && rhs->get()->isInteresting())
104           return 1;
105 
106         return (lhs->get()->getSize(), rhs->get()->getSize());
107       });
108 
109   int interestingCount = 0;
110   for (auto &var : variants) {
111     if (var->isInteresting()) {
112       ++interestingCount;
113     } else {
114       break;
115     }
116   }
117 
118   // Remove uninteresting variants.
119   variants.resize(interestingCount);
120 }
121 
122 /// Returns the number of non transformed indices.
transformSpaceSize()123 int ReductionNode::transformSpaceSize() {
124   return std::count(transformSpace.begin(), transformSpace.end(), false);
125 }
126 
127 /// Returns a vector of the transformable indices in the Module.
getTransformSpace()128 const std::vector<bool> ReductionNode::getTransformSpace() {
129   return transformSpace;
130 }
131