1 //===- FunctionSupport.cpp - Utility types for function-like ops ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/FunctionSupport.h"
10 #include "mlir/Support/LLVM.h"
11 #include "llvm/ADT/BitVector.h"
12
13 using namespace mlir;
14
15 /// Helper to call a callback once on each index in the range
16 /// [0, `totalIndices`), *except* for the indices given in `indices`.
17 /// `indices` is allowed to have duplicates and can be in any order.
iterateIndicesExcept(unsigned totalIndices,ArrayRef<unsigned> indices,function_ref<void (unsigned)> callback)18 inline void iterateIndicesExcept(unsigned totalIndices,
19 ArrayRef<unsigned> indices,
20 function_ref<void(unsigned)> callback) {
21 llvm::BitVector skipIndices(totalIndices);
22 for (unsigned i : indices)
23 skipIndices.set(i);
24
25 for (unsigned i = 0; i < totalIndices; ++i)
26 if (!skipIndices.test(i))
27 callback(i);
28 }
29
30 //===----------------------------------------------------------------------===//
31 // Function Arguments and Results.
32 //===----------------------------------------------------------------------===//
33
eraseFunctionArguments(Operation * op,ArrayRef<unsigned> argIndices,unsigned originalNumArgs,Type newType)34 void mlir::impl::eraseFunctionArguments(Operation *op,
35 ArrayRef<unsigned> argIndices,
36 unsigned originalNumArgs,
37 Type newType) {
38 // There are 3 things that need to be updated:
39 // - Function type.
40 // - Arg attrs.
41 // - Block arguments of entry block.
42 Block &entry = op->getRegion(0).front();
43 SmallString<8> nameBuf;
44
45 // Collect arg attrs to set.
46 SmallVector<MutableDictionaryAttr, 4> newArgAttrs;
47 iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
48 newArgAttrs.emplace_back(getArgAttrDict(op, i));
49 });
50
51 // Remove any arg attrs that are no longer needed.
52 for (unsigned i = newArgAttrs.size(), e = originalNumArgs; i < e; ++i)
53 op->removeAttr(getArgAttrName(i, nameBuf));
54
55 // Set the function type.
56 op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
57
58 // Set the new arg attrs, or remove them if empty.
59 for (unsigned i = 0, e = newArgAttrs.size(); i != e; ++i) {
60 auto nameAttr = getArgAttrName(i, nameBuf);
61 auto argAttr = newArgAttrs[i];
62 if (argAttr.empty())
63 op->removeAttr(nameAttr);
64 else
65 op->setAttr(nameAttr, argAttr.getDictionary(op->getContext()));
66 }
67
68 // Update the entry block's arguments.
69 entry.eraseArguments(argIndices);
70 }
71
eraseFunctionResults(Operation * op,ArrayRef<unsigned> resultIndices,unsigned originalNumResults,Type newType)72 void mlir::impl::eraseFunctionResults(Operation *op,
73 ArrayRef<unsigned> resultIndices,
74 unsigned originalNumResults,
75 Type newType) {
76 // There are 2 things that need to be updated:
77 // - Function type.
78 // - Result attrs.
79 SmallString<8> nameBuf;
80
81 // Collect result attrs to set.
82 SmallVector<MutableDictionaryAttr, 4> newResultAttrs;
83 iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
84 newResultAttrs.emplace_back(getResultAttrDict(op, i));
85 });
86
87 // Remove any result attrs that are no longer needed.
88 for (unsigned i = newResultAttrs.size(), e = originalNumResults; i < e; ++i)
89 op->removeAttr(getResultAttrName(i, nameBuf));
90
91 // Set the function type.
92 op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
93
94 // Set the new result attrs, or remove them if empty.
95 for (unsigned i = 0, e = newResultAttrs.size(); i != e; ++i) {
96 auto nameAttr = getResultAttrName(i, nameBuf);
97 auto resultAttr = newResultAttrs[i];
98 if (resultAttr.empty())
99 op->removeAttr(nameAttr);
100 else
101 op->setAttr(nameAttr, resultAttr.getDictionary(op->getContext()));
102 }
103 }
104