• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- FunctionSupport.cpp - Utility types for function-like ops ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/FunctionSupport.h"
10 #include "mlir/Support/LLVM.h"
11 #include "llvm/ADT/BitVector.h"
12 
13 using namespace mlir;
14 
15 /// Helper to call a callback once on each index in the range
16 /// [0, `totalIndices`), *except* for the indices given in `indices`.
17 /// `indices` is allowed to have duplicates and can be in any order.
iterateIndicesExcept(unsigned totalIndices,ArrayRef<unsigned> indices,function_ref<void (unsigned)> callback)18 inline void iterateIndicesExcept(unsigned totalIndices,
19                                  ArrayRef<unsigned> indices,
20                                  function_ref<void(unsigned)> callback) {
21   llvm::BitVector skipIndices(totalIndices);
22   for (unsigned i : indices)
23     skipIndices.set(i);
24 
25   for (unsigned i = 0; i < totalIndices; ++i)
26     if (!skipIndices.test(i))
27       callback(i);
28 }
29 
30 //===----------------------------------------------------------------------===//
31 // Function Arguments and Results.
32 //===----------------------------------------------------------------------===//
33 
eraseFunctionArguments(Operation * op,ArrayRef<unsigned> argIndices,unsigned originalNumArgs,Type newType)34 void mlir::impl::eraseFunctionArguments(Operation *op,
35                                         ArrayRef<unsigned> argIndices,
36                                         unsigned originalNumArgs,
37                                         Type newType) {
38   // There are 3 things that need to be updated:
39   // - Function type.
40   // - Arg attrs.
41   // - Block arguments of entry block.
42   Block &entry = op->getRegion(0).front();
43   SmallString<8> nameBuf;
44 
45   // Collect arg attrs to set.
46   SmallVector<MutableDictionaryAttr, 4> newArgAttrs;
47   iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
48     newArgAttrs.emplace_back(getArgAttrDict(op, i));
49   });
50 
51   // Remove any arg attrs that are no longer needed.
52   for (unsigned i = newArgAttrs.size(), e = originalNumArgs; i < e; ++i)
53     op->removeAttr(getArgAttrName(i, nameBuf));
54 
55   // Set the function type.
56   op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
57 
58   // Set the new arg attrs, or remove them if empty.
59   for (unsigned i = 0, e = newArgAttrs.size(); i != e; ++i) {
60     auto nameAttr = getArgAttrName(i, nameBuf);
61     auto argAttr = newArgAttrs[i];
62     if (argAttr.empty())
63       op->removeAttr(nameAttr);
64     else
65       op->setAttr(nameAttr, argAttr.getDictionary(op->getContext()));
66   }
67 
68   // Update the entry block's arguments.
69   entry.eraseArguments(argIndices);
70 }
71 
eraseFunctionResults(Operation * op,ArrayRef<unsigned> resultIndices,unsigned originalNumResults,Type newType)72 void mlir::impl::eraseFunctionResults(Operation *op,
73                                       ArrayRef<unsigned> resultIndices,
74                                       unsigned originalNumResults,
75                                       Type newType) {
76   // There are 2 things that need to be updated:
77   // - Function type.
78   // - Result attrs.
79   SmallString<8> nameBuf;
80 
81   // Collect result attrs to set.
82   SmallVector<MutableDictionaryAttr, 4> newResultAttrs;
83   iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
84     newResultAttrs.emplace_back(getResultAttrDict(op, i));
85   });
86 
87   // Remove any result attrs that are no longer needed.
88   for (unsigned i = newResultAttrs.size(), e = originalNumResults; i < e; ++i)
89     op->removeAttr(getResultAttrName(i, nameBuf));
90 
91   // Set the function type.
92   op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
93 
94   // Set the new result attrs, or remove them if empty.
95   for (unsigned i = 0, e = newResultAttrs.size(); i != e; ++i) {
96     auto nameAttr = getResultAttrName(i, nameBuf);
97     auto resultAttr = newResultAttrs[i];
98     if (resultAttr.empty())
99       op->removeAttr(nameAttr);
100     else
101       op->setAttr(nameAttr, resultAttr.getDictionary(op->getContext()));
102   }
103 }
104