1 //===- BufferAliasAnalysis.cpp - Buffer alias analysis for MLIR -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Analysis/BufferAliasAnalysis.h"
10
11 #include "mlir/Interfaces/ControlFlowInterfaces.h"
12 #include "mlir/Interfaces/ViewLikeInterface.h"
13 #include "llvm/ADT/SetOperations.h"
14
15 using namespace mlir;
16
17 /// Constructs a new alias analysis using the op provided.
BufferAliasAnalysis(Operation * op)18 BufferAliasAnalysis::BufferAliasAnalysis(Operation *op) { build(op); }
19
20 /// Find all immediate and indirect aliases this value could potentially
21 /// have. Note that the resulting set will also contain the value provided as
22 /// it is an alias of itself.
23 BufferAliasAnalysis::ValueSetT
resolve(Value rootValue) const24 BufferAliasAnalysis::resolve(Value rootValue) const {
25 ValueSetT result;
26 SmallVector<Value, 8> queue;
27 queue.push_back(rootValue);
28 while (!queue.empty()) {
29 Value currentValue = queue.pop_back_val();
30 if (result.insert(currentValue).second) {
31 auto it = aliases.find(currentValue);
32 if (it != aliases.end()) {
33 for (Value aliasValue : it->second)
34 queue.push_back(aliasValue);
35 }
36 }
37 }
38 return result;
39 }
40
41 /// Removes the given values from all alias sets.
remove(const SmallPtrSetImpl<Value> & aliasValues)42 void BufferAliasAnalysis::remove(const SmallPtrSetImpl<Value> &aliasValues) {
43 for (auto &entry : aliases)
44 llvm::set_subtract(entry.second, aliasValues);
45 }
46
47 /// This function constructs a mapping from values to its immediate aliases.
48 /// It iterates over all blocks, gets their predecessors, determines the
49 /// values that will be passed to the corresponding block arguments and
50 /// inserts them into the underlying map. Furthermore, it wires successor
51 /// regions and branch-like return operations from nested regions.
build(Operation * op)52 void BufferAliasAnalysis::build(Operation *op) {
53 // Registers all aliases of the given values.
54 auto registerAliases = [&](auto values, auto aliases) {
55 for (auto entry : llvm::zip(values, aliases))
56 this->aliases[std::get<0>(entry)].insert(std::get<1>(entry));
57 };
58
59 // Add additional aliases created by view changes to the alias list.
60 op->walk([&](ViewLikeOpInterface viewInterface) {
61 aliases[viewInterface.getViewSource()].insert(viewInterface->getResult(0));
62 });
63
64 // Query all branch interfaces to link block argument aliases.
65 op->walk([&](BranchOpInterface branchInterface) {
66 Block *parentBlock = branchInterface->getBlock();
67 for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end();
68 it != e; ++it) {
69 // Query the branch op interface to get the successor operands.
70 auto successorOperands =
71 branchInterface.getSuccessorOperands(it.getIndex());
72 if (!successorOperands.hasValue())
73 continue;
74 // Build the actual mapping of values to their immediate aliases.
75 registerAliases(successorOperands.getValue(), (*it)->getArguments());
76 }
77 });
78
79 // Query the RegionBranchOpInterface to find potential successor regions.
80 op->walk([&](RegionBranchOpInterface regionInterface) {
81 // Extract all entry regions and wire all initial entry successor inputs.
82 SmallVector<RegionSuccessor, 2> entrySuccessors;
83 regionInterface.getSuccessorRegions(/*index=*/llvm::None, entrySuccessors);
84 for (RegionSuccessor &entrySuccessor : entrySuccessors) {
85 // Wire the entry region's successor arguments with the initial
86 // successor inputs.
87 assert(entrySuccessor.getSuccessor() &&
88 "Invalid entry region without an attached successor region");
89 registerAliases(regionInterface.getSuccessorEntryOperands(
90 entrySuccessor.getSuccessor()->getRegionNumber()),
91 entrySuccessor.getSuccessorInputs());
92 }
93
94 // Wire flow between regions and from region exits.
95 for (Region ®ion : regionInterface->getRegions()) {
96 // Iterate over all successor region entries that are reachable from the
97 // current region.
98 SmallVector<RegionSuccessor, 2> successorRegions;
99 regionInterface.getSuccessorRegions(region.getRegionNumber(),
100 successorRegions);
101 for (RegionSuccessor &successorRegion : successorRegions) {
102 // Iterate over all immediate terminator operations and wire the
103 // successor inputs with the operands of each terminator.
104 for (Block &block : region) {
105 for (Operation &operation : block) {
106 if (operation.hasTrait<OpTrait::ReturnLike>())
107 registerAliases(operation.getOperands(),
108 successorRegion.getSuccessorInputs());
109 }
110 }
111 }
112 }
113 });
114 }
115