1 //===- Builders.h - MLIR Declarative Builder Classes ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Provides intuitive composable interfaces for building structured MLIR
10 // snippets in a declarative fashion.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #ifndef MLIR_DIALECT_AFFINE_EDSC_BUILDERS_H_
15 #define MLIR_DIALECT_AFFINE_EDSC_BUILDERS_H_
16
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/EDSC/Builders.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Types.h"
21
22 namespace mlir {
23 namespace edsc {
24
25 /// Creates a perfect nest of affine "for" loops, given the list of lower
26 /// bounds, upper bounds and steps. The three lists are expected to contain the
27 /// same number of elements. Uses the OpBuilder and Location stored in
28 /// ScopedContext and assumes they are non-null. The optional "bodyBuilderFn"
29 /// callback is called to construct the body of the innermost loop and is passed
30 /// the list of loop induction variables, in order from outermost to innermost.
31 /// The function is expected to use the builder and location stored in
32 /// ScopedContext at the moment of the call. The function should not create
33 /// the affine terminator op, which will be added regardless of the
34 /// "bodyBuilderFn" being present.
35 void affineLoopNestBuilder(
36 ValueRange lbs, ValueRange ubs, ArrayRef<int64_t> steps,
37 function_ref<void(ValueRange)> bodyBuilderFn = nullptr);
38
39 /// Creates a single affine "for" loop, iterating from max(lbs) to min(ubs) with
40 /// the given step. Uses the OpBuilder and Location stored in ScopedContext and
41 /// assumes they are non-null. The optional "bodyBuilderFn" callback is called
42 /// to construct the body of the loop and is passed the induction variable. The
43 /// function is expected to use the builder and location stored in ScopedContext
44 /// at the moment of the call. The function should not create the affine
45 /// terminator op, which will be added regardless of the "bodyBuilderFn" being
46 /// present.
47 void affineLoopBuilder(ValueRange lbs, ValueRange ubs, int64_t step,
48 function_ref<void(Value)> bodyBuilderFn = nullptr);
49
50 /// Creates a single affine "for" loop, iterating from max(lbs) to min(ubs) with
51 /// the given step. Uses the OpBuilder and Location stored in ScopedContext and
52 /// assumes they are non-null. "iterArgs" is used to specify the initial values
53 /// of the result affine "for" might yield. The optional "bodyBuilderFn"
54 /// callback is called to construct the body of the loop and is passed the
55 /// induction variable and the iteration arguments. The function is expected to
56 /// use the builder and location stored in ScopedContext at the moment of the
57 /// call. The function will create the affine terminator op in case "iterArgs"
58 /// is empty and "bodyBuilderFn" is not present.
59 void affineLoopBuilder(
60 ValueRange lbs, ValueRange ubs, int64_t step, ValueRange iterArgs,
61 function_ref<void(Value, ValueRange)> bodyBuilderFn = nullptr);
62 namespace op {
63
64 Value operator+(Value lhs, Value rhs);
65 Value operator-(Value lhs, Value rhs);
66 Value operator*(Value lhs, Value rhs);
67 Value operator/(Value lhs, Value rhs);
68 Value operator%(Value lhs, Value rhs);
69 Value floorDiv(Value lhs, Value rhs);
70 Value ceilDiv(Value lhs, Value rhs);
71
72 /// Logical operator overloadings.
73 Value negate(Value value);
74 Value operator&&(Value lhs, Value rhs);
75 Value operator||(Value lhs, Value rhs);
76 Value operator^(Value lhs, Value rhs);
77
78 /// Comparison operator overloadings.
79 Value eq(Value lhs, Value rhs);
80 Value ne(Value lhs, Value rhs);
81 Value slt(Value lhs, Value rhs);
82 Value sle(Value lhs, Value rhs);
83 Value sgt(Value lhs, Value rhs);
84 Value sge(Value lhs, Value rhs);
85 Value ult(Value lhs, Value rhs);
86 Value ule(Value lhs, Value rhs);
87 Value ugt(Value lhs, Value rhs);
88 Value uge(Value lhs, Value rhs);
89
90 } // namespace op
91
92 /// Arithmetic operator overloadings.
93 template <typename Load, typename Store>
94 Value TemplatedIndexedValue<Load, Store>::operator+(Value e) {
95 using op::operator+;
96 return static_cast<Value>(*this) + e;
97 }
98 template <typename Load, typename Store>
99 Value TemplatedIndexedValue<Load, Store>::operator-(Value e) {
100 using op::operator-;
101 return static_cast<Value>(*this) - e;
102 }
103 template <typename Load, typename Store>
104 Value TemplatedIndexedValue<Load, Store>::operator*(Value e) {
105 using op::operator*;
106 return static_cast<Value>(*this) * e;
107 }
108 template <typename Load, typename Store>
109 Value TemplatedIndexedValue<Load, Store>::operator/(Value e) {
110 using op::operator/;
111 return static_cast<Value>(*this) / e;
112 }
113 template <typename Load, typename Store>
114 Value TemplatedIndexedValue<Load, Store>::operator%(Value e) {
115 using op::operator%;
116 return static_cast<Value>(*this) % e;
117 }
118 template <typename Load, typename Store>
119 Value TemplatedIndexedValue<Load, Store>::operator^(Value e) {
120 using op::operator^;
121 return static_cast<Value>(*this) ^ e;
122 }
123
124 /// Assignment-arithmetic operator overloadings.
125 template <typename Load, typename Store>
126 Store TemplatedIndexedValue<Load, Store>::operator+=(Value e) {
127 using op::operator+;
128 return Store(*this + e, getBase(), indices);
129 }
130 template <typename Load, typename Store>
131 Store TemplatedIndexedValue<Load, Store>::operator-=(Value e) {
132 using op::operator-;
133 return Store(*this - e, getBase(), indices);
134 }
135 template <typename Load, typename Store>
136 Store TemplatedIndexedValue<Load, Store>::operator*=(Value e) {
137 using op::operator*;
138 return Store(*this * e, getBase(), indices);
139 }
140 template <typename Load, typename Store>
141 Store TemplatedIndexedValue<Load, Store>::operator/=(Value e) {
142 using op::operator/;
143 return Store(*this / e, getBase(), indices);
144 }
145 template <typename Load, typename Store>
146 Store TemplatedIndexedValue<Load, Store>::operator%=(Value e) {
147 using op::operator%;
148 return Store(*this % e, getBase(), indices);
149 }
150 template <typename Load, typename Store>
151 Store TemplatedIndexedValue<Load, Store>::operator^=(Value e) {
152 using op::operator^;
153 return Store(*this ^ e, getBase(), indices);
154 }
155
156 /// Logical operator overloadings.
157 template <typename Load, typename Store>
158 Value TemplatedIndexedValue<Load, Store>::operator&&(Value e) {
159 using op::operator&&;
160 return static_cast<Value>(*this) && e;
161 }
162 template <typename Load, typename Store>
163 Value TemplatedIndexedValue<Load, Store>::operator||(Value e) {
164 using op::operator||;
165 return static_cast<Value>(*this) || e;
166 }
167
168 /// Comparison operator overloadings.
169 template <typename Load, typename Store>
eq(Value e)170 Value TemplatedIndexedValue<Load, Store>::eq(Value e) {
171 return eq(value, e);
172 }
173 template <typename Load, typename Store>
ne(Value e)174 Value TemplatedIndexedValue<Load, Store>::ne(Value e) {
175 return ne(value, e);
176 }
177 template <typename Load, typename Store>
slt(Value e)178 Value TemplatedIndexedValue<Load, Store>::slt(Value e) {
179 using op::slt;
180 return slt(static_cast<Value>(*this), e);
181 }
182 template <typename Load, typename Store>
sle(Value e)183 Value TemplatedIndexedValue<Load, Store>::sle(Value e) {
184 using op::sle;
185 return sle(static_cast<Value>(*this), e);
186 }
187 template <typename Load, typename Store>
sgt(Value e)188 Value TemplatedIndexedValue<Load, Store>::sgt(Value e) {
189 using op::sgt;
190 return sgt(static_cast<Value>(*this), e);
191 }
192 template <typename Load, typename Store>
sge(Value e)193 Value TemplatedIndexedValue<Load, Store>::sge(Value e) {
194 using op::sge;
195 return sge(static_cast<Value>(*this), e);
196 }
197 template <typename Load, typename Store>
ult(Value e)198 Value TemplatedIndexedValue<Load, Store>::ult(Value e) {
199 using op::ult;
200 return ult(static_cast<Value>(*this), e);
201 }
202 template <typename Load, typename Store>
ule(Value e)203 Value TemplatedIndexedValue<Load, Store>::ule(Value e) {
204 using op::ule;
205 return ule(static_cast<Value>(*this), e);
206 }
207 template <typename Load, typename Store>
ugt(Value e)208 Value TemplatedIndexedValue<Load, Store>::ugt(Value e) {
209 using op::ugt;
210 return ugt(static_cast<Value>(*this), e);
211 }
212 template <typename Load, typename Store>
uge(Value e)213 Value TemplatedIndexedValue<Load, Store>::uge(Value e) {
214 using op::uge;
215 return uge(static_cast<Value>(*this), e);
216 }
217
218 } // namespace edsc
219 } // namespace mlir
220
221 #endif // MLIR_DIALECT_AFFINE_EDSC_BUILDERS_H_
222