• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- BlockSupport.h -------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines a number of support types for the Block class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_BLOCK_SUPPORT_H
14 #define MLIR_IR_BLOCK_SUPPORT_H
15 
16 #include "mlir/IR/Value.h"
17 #include "llvm/ADT/PointerUnion.h"
18 #include "llvm/ADT/ilist.h"
19 #include "llvm/ADT/ilist_node.h"
20 
21 namespace mlir {
22 class Block;
23 
24 //===----------------------------------------------------------------------===//
25 // Predecessors
26 //===----------------------------------------------------------------------===//
27 
28 /// Implement a predecessor iterator for blocks. This works by walking the use
29 /// lists of the blocks. The entries on this list are the BlockOperands that
30 /// are embedded into terminator operations. From the operand, we can get the
31 /// terminator that contains it, and its parent block is the predecessor.
32 class PredecessorIterator final
33     : public llvm::mapped_iterator<ValueUseIterator<BlockOperand>,
34                                    Block *(*)(BlockOperand &)> {
35   static Block *unwrap(BlockOperand &value);
36 
37 public:
38   using reference = Block *;
39 
40   /// Initializes the operand type iterator to the specified operand iterator.
PredecessorIterator(ValueUseIterator<BlockOperand> it)41   PredecessorIterator(ValueUseIterator<BlockOperand> it)
42       : llvm::mapped_iterator<ValueUseIterator<BlockOperand>,
43                               Block *(*)(BlockOperand &)>(it, &unwrap) {}
PredecessorIterator(BlockOperand * operand)44   explicit PredecessorIterator(BlockOperand *operand)
45       : PredecessorIterator(ValueUseIterator<BlockOperand>(operand)) {}
46 
47   /// Get the successor number in the predecessor terminator.
48   unsigned getSuccessorIndex() const;
49 };
50 
51 //===----------------------------------------------------------------------===//
52 // Successors
53 //===----------------------------------------------------------------------===//
54 
55 /// This class implements the successor iterators for Block.
56 class SuccessorRange final
57     : public llvm::detail::indexed_accessor_range_base<
58           SuccessorRange, BlockOperand *, Block *, Block *, Block *> {
59 public:
60   using RangeBaseT::RangeBaseT;
61   SuccessorRange();
62   SuccessorRange(Block *block);
63   SuccessorRange(Operation *term);
64 
65 private:
66   /// See `llvm::detail::indexed_accessor_range_base` for details.
offset_base(BlockOperand * object,ptrdiff_t index)67   static BlockOperand *offset_base(BlockOperand *object, ptrdiff_t index) {
68     return object + index;
69   }
70   /// See `llvm::detail::indexed_accessor_range_base` for details.
dereference_iterator(BlockOperand * object,ptrdiff_t index)71   static Block *dereference_iterator(BlockOperand *object, ptrdiff_t index) {
72     return object[index].get();
73   }
74 
75   /// Allow access to `offset_base` and `dereference_iterator`.
76   friend RangeBaseT;
77 };
78 
79 //===----------------------------------------------------------------------===//
80 // BlockRange
81 //===----------------------------------------------------------------------===//
82 
83 /// This class provides an abstraction over the different types of ranges over
84 /// Blocks. In many cases, this prevents the need to explicitly materialize a
85 /// SmallVector/std::vector. This class should be used in places that are not
86 /// suitable for a more derived type (e.g. ArrayRef) or a template range
87 /// parameter.
88 class BlockRange final
89     : public llvm::detail::indexed_accessor_range_base<
90           BlockRange, llvm::PointerUnion<BlockOperand *, Block *const *>,
91           Block *, Block *, Block *> {
92 public:
93   using RangeBaseT::RangeBaseT;
94   BlockRange(ArrayRef<Block *> blocks = llvm::None);
95   BlockRange(SuccessorRange successors);
96   template <typename Arg,
97             typename = typename std::enable_if_t<
98                 std::is_constructible<ArrayRef<Block *>, Arg>::value>>
BlockRange(Arg && arg)99   BlockRange(Arg &&arg)
100       : BlockRange(ArrayRef<Block *>(std::forward<Arg>(arg))) {}
BlockRange(std::initializer_list<Block * > blocks)101   BlockRange(std::initializer_list<Block *> blocks)
102       : BlockRange(ArrayRef<Block *>(blocks)) {}
103 
104 private:
105   /// The owner of the range is either:
106   /// * A pointer to the first element of an array of block operands.
107   /// * A pointer to the first element of an array of Block *.
108   using OwnerT = llvm::PointerUnion<BlockOperand *, Block *const *>;
109 
110   /// See `llvm::detail::indexed_accessor_range_base` for details.
111   static OwnerT offset_base(OwnerT object, ptrdiff_t index);
112 
113   /// See `llvm::detail::indexed_accessor_range_base` for details.
114   static Block *dereference_iterator(OwnerT object, ptrdiff_t index);
115 
116   /// Allow access to `offset_base` and `dereference_iterator`.
117   friend RangeBaseT;
118 };
119 
120 //===----------------------------------------------------------------------===//
121 // Operation Iterators
122 //===----------------------------------------------------------------------===//
123 
124 namespace detail {
125 /// A utility iterator that filters out operations that are not 'OpT'.
126 template <typename OpT, typename IteratorT>
127 class op_filter_iterator
128     : public llvm::filter_iterator<IteratorT, bool (*)(Operation &)> {
filter(Operation & op)129   static bool filter(Operation &op) { return llvm::isa<OpT>(op); }
130 
131 public:
op_filter_iterator(IteratorT it,IteratorT end)132   op_filter_iterator(IteratorT it, IteratorT end)
133       : llvm::filter_iterator<IteratorT, bool (*)(Operation &)>(it, end,
134                                                                 &filter) {}
135 
136   /// Allow implicit conversion to the underlying iterator.
IteratorT()137   operator IteratorT() const { return this->wrapped(); }
138 };
139 
140 /// This class provides iteration over the held operations of a block for a
141 /// specific operation type.
142 template <typename OpT, typename IteratorT>
143 class op_iterator
144     : public llvm::mapped_iterator<op_filter_iterator<OpT, IteratorT>,
145                                    OpT (*)(Operation &)> {
unwrap(Operation & op)146   static OpT unwrap(Operation &op) { return cast<OpT>(op); }
147 
148 public:
149   using reference = OpT;
150 
151   /// Initializes the iterator to the specified filter iterator.
op_iterator(op_filter_iterator<OpT,IteratorT> it)152   op_iterator(op_filter_iterator<OpT, IteratorT> it)
153       : llvm::mapped_iterator<op_filter_iterator<OpT, IteratorT>,
154                               OpT (*)(Operation &)>(it, &unwrap) {}
155 
156   /// Allow implicit conversion to the underlying block iterator.
IteratorT()157   operator IteratorT() const { return this->wrapped(); }
158 };
159 } // end namespace detail
160 } // end namespace mlir
161 
162 namespace llvm {
163 
164 /// Provide support for hashing successor ranges.
165 template <>
166 struct DenseMapInfo<mlir::SuccessorRange> {
167   static mlir::SuccessorRange getEmptyKey() {
168     auto *pointer = llvm::DenseMapInfo<mlir::BlockOperand *>::getEmptyKey();
169     return mlir::SuccessorRange(pointer, 0);
170   }
171   static mlir::SuccessorRange getTombstoneKey() {
172     auto *pointer = llvm::DenseMapInfo<mlir::BlockOperand *>::getTombstoneKey();
173     return mlir::SuccessorRange(pointer, 0);
174   }
175   static unsigned getHashValue(mlir::SuccessorRange value) {
176     return llvm::hash_combine_range(value.begin(), value.end());
177   }
178   static bool isEqual(mlir::SuccessorRange lhs, mlir::SuccessorRange rhs) {
179     if (rhs.getBase() == getEmptyKey().getBase())
180       return lhs.getBase() == getEmptyKey().getBase();
181     if (rhs.getBase() == getTombstoneKey().getBase())
182       return lhs.getBase() == getTombstoneKey().getBase();
183     return lhs == rhs;
184   }
185 };
186 
187 //===----------------------------------------------------------------------===//
188 // ilist_traits for Operation
189 //===----------------------------------------------------------------------===//
190 
191 namespace ilist_detail {
192 // Explicitly define the node access for the operation list so that we can
193 // break the dependence on the Operation class in this header. This allows for
194 // operations to have trailing Regions without a circular include
195 // dependence.
196 template <>
197 struct SpecificNodeAccess<
198     typename compute_node_options<::mlir::Operation>::type> : NodeAccess {
199 protected:
200   using OptionsT = typename compute_node_options<mlir::Operation>::type;
201   using pointer = typename OptionsT::pointer;
202   using const_pointer = typename OptionsT::const_pointer;
203   using node_type = ilist_node_impl<OptionsT>;
204 
205   static node_type *getNodePtr(pointer N);
206   static const node_type *getNodePtr(const_pointer N);
207 
208   static pointer getValuePtr(node_type *N);
209   static const_pointer getValuePtr(const node_type *N);
210 };
211 } // end namespace ilist_detail
212 
213 template <> struct ilist_traits<::mlir::Operation> {
214   using Operation = ::mlir::Operation;
215   using op_iterator = simple_ilist<Operation>::iterator;
216 
217   static void deleteNode(Operation *op);
218   void addNodeToList(Operation *op);
219   void removeNodeFromList(Operation *op);
220   void transferNodesFromList(ilist_traits<Operation> &otherList,
221                              op_iterator first, op_iterator last);
222 
223 private:
224   mlir::Block *getContainingBlock();
225 };
226 
227 //===----------------------------------------------------------------------===//
228 // ilist_traits for Block
229 //===----------------------------------------------------------------------===//
230 
231 template <>
232 struct ilist_traits<::mlir::Block> : public ilist_alloc_traits<::mlir::Block> {
233   using Block = ::mlir::Block;
234   using block_iterator = simple_ilist<::mlir::Block>::iterator;
235 
236   void addNodeToList(Block *block);
237   void removeNodeFromList(Block *block);
238   void transferNodesFromList(ilist_traits<Block> &otherList,
239                              block_iterator first, block_iterator last);
240 
241 private:
242   mlir::Region *getParentRegion();
243 };
244 
245 } // end namespace llvm
246 
247 #endif // MLIR_IR_BLOCK_SUPPORT_H
248