//===- Builders.h - MLIR Declarative Linalg Builders ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Provides intuitive composable interfaces for building structured MLIR
// snippets in a declarative fashion.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_
#define MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/EDSC/Builders.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"

namespace mlir {
class AffineForOp;
class BlockArgument;

namespace scf {
class ParallelOp;
} // namespace scf

namespace edsc {
/// A no-op region builder, used as the default `regionBuilder` argument for
/// `makeGenericLinalgOp` below.
inline void defaultRegionBuilder(ValueRange args) {}

/// Build a `linalg.generic` op with the specified `inputs`, `outputBuffers`,
/// `initTensors`, `resultTensorTypes` and `region`.
///
/// `otherValues` and `otherAttributes` may be passed and will be appended as
/// operands and attributes respectively.
///
/// Prerequisites:
/// =============
///
/// 1. `inputs` may contain StructuredIndexed that capture either buffer or
///    tensor values.
/// 2. `outputBuffers` may contain StructuredIndexed that capture buffer
///    values.
/// 3. `initTensors` contains tensor values, without indexing maps.
/// 4. `resultTensorTypes` may contain StructuredIndexed that capture return
///    tensor types.
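///
/// Example: a minimal sketch of a 2-D pointwise-style invocation. Here `A`,
/// `B` and `C` are hypothetical memref Values, `i` and `j` are AffineExprs
/// bound to the two dimensions, an enclosing ScopedContext is assumed to be
/// active, and `IteratorType::Parallel` is assumed to denote a parallel
/// iterator:
///
/// ```
/// StructuredIndexed sA(A), sB(B), sC(C);
/// SmallVector<IteratorType, 2> iters(2, IteratorType::Parallel);
/// makeGenericLinalgOp(iters,
///                     /*inputs=*/{sA({i, j}), sB({i, j})},
///                     /*outputBuffers=*/{sC({i, j})},
///                     /*initTensors=*/{}, /*resultTensorTypes=*/{},
///                     [](ValueRange args) { /* scalar body */ });
/// ```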
Operation *makeGenericLinalgOp(
    ArrayRef<IteratorType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
    ArrayRef<StructuredIndexed> outputBuffers, ArrayRef<Value> initTensors,
    ArrayRef<StructuredIndexed> resultTensorTypes,
    function_ref<void(ValueRange)> regionBuilder = defaultRegionBuilder,
    ArrayRef<Value> otherValues = {}, ArrayRef<Attribute> otherAttributes = {});

namespace ops {
using edsc::StructuredIndexed;

//===----------------------------------------------------------------------===//
// EDSC builders for linalg generic operations.
//===----------------------------------------------------------------------===//

/// Build the body of a region to compute a scalar multiply, under the current
/// ScopedContext, at the current insert point.
void mulRegionBuilder(ValueRange args);

/// Build the body of a region to compute a scalar multiply-accumulate, under
/// the current ScopedContext, at the current insert point.
void macRegionBuilder(ValueRange args);
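
/// A custom region builder follows the same pattern: read the block arguments
/// and yield the computed scalar. A minimal sketch (assuming the
/// `linalg_yield` intrinsic and the EDSC arithmetic operator overloads are
/// available; `subRegionBuilder` is a hypothetical name):
///
/// ```
/// void subRegionBuilder(ValueRange args) {
///   using edsc::op::operator-;
///   assert(args.size() == 2 && "expected 2 block arguments");
///   Value a(args[0]), b(args[1]);
///   linalg_yield(a - b); // yields O(...) = I1(...) - I2(...)
/// }
/// ```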

/// TODO: In the future we should tie these implementations to something in
/// Tablegen that generates the proper interfaces and the proper sugared named
/// ops.

/// Build a linalg.pointwise, under the current ScopedContext, at the current
/// insert point, that computes:
/// ```
///    (i0, ..., in) = (par, ..., par)
///    |
///    |  O...(some_subset...(i0, ..., in)) =
///    |    some_pointwise_func...(I...(some_other_subset...(i0, ..., in)))
/// ```
///
/// This is a very generic entry point that can be configured in many ways to
/// build a perfect loop nest of parallel loops with arbitrarily complex
/// innermost loop code and whatever (explicit) broadcast semantics.
///
/// This can be used with both out-of-place and in-place semantics.
/// The client is responsible for ensuring the region operations are compatible
/// with in-place semantics and parallelism.

/// Unary pointwise operation (with broadcast) entry point.
using UnaryPointwiseOpBuilder = function_ref<Value(Value)>;
Operation *linalg_generic_pointwise(UnaryPointwiseOpBuilder unaryOp,
                                    StructuredIndexed I, StructuredIndexed O);
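
/// For instance, the sugared `linalg_generic_pointwise_tanh` below can be
/// expressed through this entry point (a sketch, assuming an `std_tanh` EDSC
/// intrinsic for the scalar `tanh`):
///
/// ```
/// linalg_generic_pointwise(
///     [](Value a) -> Value { return std_tanh(a); }, I, O);
/// ```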

/// Build a linalg.pointwise with all `parallel` iterators and a region that
/// computes `O = tanh(I)`. The client is responsible for specifying the proper
/// indexings when creating the StructuredIndexed.
Operation *linalg_generic_pointwise_tanh(StructuredIndexed I,
                                         StructuredIndexed O);

/// Binary pointwise operation (with broadcast) entry point.
using BinaryPointwiseOpBuilder = function_ref<Value(Value, Value)>;
Operation *linalg_generic_pointwise(BinaryPointwiseOpBuilder binaryOp,
                                    StructuredIndexed I1, StructuredIndexed I2,
                                    StructuredIndexed O);

/// Build a linalg.pointwise with all `parallel` iterators and a region that
/// computes `O = I1 + I2`. The client is responsible for specifying the proper
/// indexings when creating the StructuredIndexed.
Operation *linalg_generic_pointwise_add(StructuredIndexed I1,
                                        StructuredIndexed I2,
                                        StructuredIndexed O);
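
/// Broadcast is expressed purely through the indexings. For example, adding a
/// 1-D bias `B` into a 2-D output (a sketch; `A`, `B`, `C` are hypothetical
/// Values and `i`, `j` bound AffineExprs under an active ScopedContext):
///
/// ```
/// StructuredIndexed sA(A), sB(B), sC(C);
/// linalg_generic_pointwise_add(sA({i, j}), sB({j}), sC({i, j}));
/// ```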

/// Build a linalg.pointwise with all `parallel` iterators and a region that
/// computes `O = max(I1, I2)`. The client is responsible for specifying the
/// proper indexings when creating the StructuredIndexed.
Operation *linalg_generic_pointwise_max(StructuredIndexed I1,
                                        StructuredIndexed I2,
                                        StructuredIndexed O);

// TODO: Implement more useful pointwise operations on a per-need basis.

using MatmulRegionBuilder = function_ref<void(ValueRange args)>;

/// Build a linalg.generic, under the current ScopedContext, at the current
/// insert point, that computes:
/// ```
///    (m, n, k) = (par, par, seq)
///    |
///    |  C(m, n) += A(m, k) * B(k, n)
/// ```
Operation *
linalg_generic_matmul(Value vA, Value vB, Value vC,
                      MatmulRegionBuilder regionBuilder = macRegionBuilder);
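
/// A minimal usage sketch (assuming `A`, `B`, `C` are memref Values of
/// compatible 2-D shapes and a ScopedContext is active):
///
/// ```
/// linalg_generic_matmul(A, B, C); // C(m, n) += A(m, k) * B(k, n)
/// ```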

/// Build a linalg.generic, under the current ScopedContext, at the current
/// insert point, that computes:
/// ```
///    (m, n, k) = (par, par, seq)
///    |
///    |  D(m, n) = C(m, n) + sum_k(A(m, k) * B(k, n))
/// ```
/// and returns the tensor `D`.
Operation *
linalg_generic_matmul(Value vA, Value vB, Value vC, RankedTensorType tD,
                      MatmulRegionBuilder regionBuilder = macRegionBuilder);

template <typename Container>
Operation *
linalg_generic_matmul(Container values,
                      MatmulRegionBuilder regionBuilder = macRegionBuilder) {
  assert(values.size() == 3 && "Expected exactly 3 values");
  return linalg_generic_matmul(values[0], values[1], values[2], regionBuilder);
}

/// Build a linalg.generic, under the current ScopedContext, at the current
/// insert point, that computes:
/// ```
///    (batch, f, [h, w, ...], [kh, kw, ...], c) =
///    |  (par, par, [par, par, ...], [red, red, ...], red)
///    |
///    |  O(batch, [h, w, ...], f) +=
///    |    I(batch,
///    |      [
///    |        stride[0] * h + dilations[0] * kh,
///    |        stride[1] * w + dilations[1] * kw, ...
///    |      ],
///    |      c)
///    |    *
///    |    W([kh, kw, ...], c, f)
/// ```
/// If `dilations` or `strides` are left empty, the default value of `1` is used
/// along each relevant dimension.
///
/// For now `...` must be empty (i.e. only 2-D convolutions are supported).
///
// TODO: Extend convolution rank with some template magic.
Operation *linalg_generic_conv_nhwc(Value vI, Value vW, Value vO,
                                    ArrayRef<int> strides = {},
                                    ArrayRef<int> dilations = {});
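
/// A minimal usage sketch (assuming NHWC-shaped memref Values `I`, `W`, `O`
/// under an active ScopedContext):
///
/// ```
/// // Stride 2 and unit dilation along both h and w.
/// linalg_generic_conv_nhwc(I, W, O, /*strides=*/{2, 2},
///                          /*dilations=*/{1, 1});
/// ```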

template <typename Container>
Operation *linalg_generic_conv_nhwc(Container values,
                                    ArrayRef<int> strides = {},
                                    ArrayRef<int> dilations = {}) {
  assert(values.size() == 3 && "Expected exactly 3 values");
  return linalg_generic_conv_nhwc(values[0], values[1], values[2], strides,
                                  dilations);
}

/// Build a linalg.generic, under the current ScopedContext, at the current
/// insert point, that computes:
/// ```
///    (batch, dm, c, [h, w, ...], [kh, kw, ...]) =
///    |  (par, par, par, [par, par, ...], [red, red, ...])
///    |
///    |  O(batch, [h, w, ...], c * depth_multiplier) +=
///    |    I(batch,
///    |      [
///    |        stride[0] * h + dilations[0] * kh,
///    |        stride[1] * w + dilations[1] * kw, ...
///    |      ],
///    |      c)
///    |    *
///    |    W([kh, kw, ...], c, depth_multiplier)
/// ```
/// If `dilations` or `strides` are left empty, the default value of `1` is used
/// along each relevant dimension.
///
/// For now `...` must be empty (i.e. only 2-D convolutions are supported).
///
// TODO: Extend convolution rank with some template magic.
Operation *linalg_generic_dilated_conv_nhwc(Value vI, Value vW, Value vO,
                                            int depth_multiplier = 1,
                                            ArrayRef<int> strides = {},
                                            ArrayRef<int> dilations = {});
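
/// A minimal usage sketch (assuming suitably shaped memref Values `I`, `W`,
/// `O` under an active ScopedContext):
///
/// ```
/// // Depth multiplier 2, unit strides, dilation 3 along both h and w.
/// linalg_generic_dilated_conv_nhwc(I, W, O, /*depth_multiplier=*/2,
///                                  /*strides=*/{1, 1},
///                                  /*dilations=*/{3, 3});
/// ```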

template <typename Container>
Operation *linalg_generic_dilated_conv_nhwc(Container values,
                                            int depth_multiplier,
                                            ArrayRef<int> strides = {},
                                            ArrayRef<int> dilations = {}) {
  assert(values.size() == 3 && "Expected exactly 3 values");
  return linalg_generic_dilated_conv_nhwc(values[0], values[1], values[2],
                                          depth_multiplier, strides,
                                          dilations);
}

} // namespace ops
} // namespace edsc
} // namespace mlir

#endif // MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_