/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace lmhlo {
namespace {

// Clones and adapts the code in `lhlo_block`, which operates on buffers and
// produces a single output buffer, so that it can be applied to `operands`
// whose types are the element types of the respective buffers. Returns the
// computed value.
//
// Example. For `operands` of types (f32, i32) and a block of LHLO ops with
// signature:
//   ^bb(%lhs: memref<f32>, %rhs: memref<i32>, %res: memref<i1>):
//     <LHLO_ops>
//
// this inserts the necessary alloc and store ops and returns a result of
// type `i1`.
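//
// A rough sketch of the IR this emits for the example above (SSA names are
// illustrative, not produced verbatim):
//
//   %lhs_buf = memref.alloc() : memref<f32>
//   %rhs_buf = memref.alloc() : memref<i32>
//   %res_buf = memref.alloc() : memref<i1>
//   memref.store %lhs, %lhs_buf[] : memref<f32>
//   memref.store %rhs, %rhs_buf[] : memref<i32>
//   <cloned LHLO_ops, now reading and writing the local buffers>
//   %result = memref.load %res_buf[] : memref<i1>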
Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
                                Block* lhlo_block, OpBuilder* b) {
  SmallVector<Value, 2> arg_bufs;
  for (auto arg_type : lhlo_block->getArgumentTypes()) {
    arg_bufs.push_back(
        b->create<memref::AllocOp>(loc, arg_type.cast<MemRefType>()));
  }
  for (auto operand : llvm::enumerate(operands)) {
    b->create<memref::StoreOp>(loc, operand.value(), arg_bufs[operand.index()]);
  }
  // Clone the ops from `lhlo_block`.
  BlockAndValueMapping mapping;
  mapping.map(lhlo_block->getArguments(), arg_bufs);
  for (auto& nested : lhlo_block->without_terminator()) {
    auto clone = b->clone(nested, mapping);
    mapping.map(nested.getResults(), clone->getResults());
  }
  return b->create<memref::LoadOp>(loc, arg_bufs.back());
}

// Converts a block with LHLO ops and signature:
//   ^bb(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
// into the reduction operator of an `scf.reduce` op by allocating buffers for
// the scalar arguments and the result, so that the cloned LHLO ops stay
// compatible with their buffer semantics.
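//
// For the block above, the generated `scf.reduce` region roughly looks like
// this (a sketch, not verbatim printer output):
//
//   scf.reduce(%elem) {
//   ^bb0(%lhs: f32, %rhs: f32):
//     <allocs + stores + cloned LHLO ops, as in ApplySingleResultLhloCode>
//     scf.reduce.return %result : f32
//   } : f32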
void ConvertToReductionOperator(Location loc, scf::ReduceOp reduce_op,
                                Block* lhlo_block, OpBuilder* b) {
  Block& loop_reduce_op_body = reduce_op.reductionOperator().front();
  OpBuilder::InsertionGuard guard(*b);
  b->setInsertionPointToStart(&loop_reduce_op_body);
  b->create<scf::ReduceReturnOp>(
      loc, ApplySingleResultLhloCode(loc, loop_reduce_op_body.getArguments(),
                                     lhlo_block, b));
}

// Returns the result of a ConstantIndexOp if `dim` is static; otherwise emits
// a memref::DimOp to extract the dimension at runtime.
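//
// For example, for a hypothetical %v : memref<100x?xf32>:
//   dim_index = 0, dim = 100                      -> constant 100 : index
//   dim_index = 1, dim = ShapedType::kDynamicSize -> memref.dim %v, %c1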
Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value,
                            size_t dim_index, int64_t dim, OpBuilder* b) {
  return dim == ShapedType::kDynamicSize
             ? b->create<memref::DimOp>(loc, shaped_value, dim_index)
                   .getResult()
             : b->create<ConstantIndexOp>(loc, dim);
}

struct MappedIvs {
  // False if the mapped indices are in the padding area, true otherwise.
  Value in_bounds;
  // Mapped indices.
  SmallVector<Value, 2> ivs;
};

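// Maps the induction variables of the output loop (`ivs`) and of the window
// loop (`window_ivs`) of `op` to indices into `operand`:
//
//   index[i] = ivs[i] * stride[i] + window_ivs[i] - pad_low[i]
//
// For instance (illustrative numbers), with stride = 2, pad_low = 1,
// ivs[i] = 3 and window_ivs[i] = 0 this yields index[i] = 3 * 2 + 0 - 1 = 5.
// Also computes `in_bounds`, which is false iff any mapped index falls into
// the padding area.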
template <typename OpTy>
MappedIvs MapWindowIvsToInput(OpTy op, Value operand, ValueRange ivs,
                              ValueRange window_ivs, OpBuilder* b) {
  MappedIvs mapped_ivs;

  if (!op.window_strides().hasValue()) {
    op.emitOpError("No window strides specified.");
  }
  auto window_strides = op.window_strides().getValue();

  if (!op.padding().hasValue()) {
    op.emitOpError("No padding specified.");
  }
  auto padding = op.padding().getValue();

  auto loc = op.getLoc();
  auto operand_shape = operand.getType().template cast<MemRefType>().getShape();

  // `in_bounds` is false when the mapped indices are in the padding area.
  mapped_ivs.in_bounds = b->create<mlir::ConstantOp>(
      loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));
  for (unsigned i = 0, e = ivs.size(); i < e; ++i) {
    auto stride = window_strides.template getValue<llvm::APInt>(i);
    auto pad_low = padding.template getValue<llvm::APInt>({i, 0});

    Value stride_val = b->create<ConstantIndexOp>(loc, stride.getSExtValue());
    Value pad_low_val = b->create<ConstantIndexOp>(loc, pad_low.getSExtValue());

    Value center = b->create<MulIOp>(loc, ivs[i], stride_val);
    Value offset = b->create<SubIOp>(loc, window_ivs[i], pad_low_val);
    Value index = b->create<AddIOp>(loc, center, offset);
    Value upper_bound =
        GetStaticOrDynamicDim(loc, operand, i, operand_shape[i], b);
    // We must check whether 0 <= index_i < shape_i, as otherwise we are in
    // the pad and then we have to use the neutral element for reduction.
    // Equivalently, it can be computed as the unsigned comparison index_i <
    // shape_i, since a negative value wraps to a large positive value.
    mapped_ivs.in_bounds = b->create<mlir::AndOp>(
        loc, mapped_ivs.in_bounds,
        b->create<CmpIOp>(loc, CmpIPredicate::ult, index, upper_bound));
    mapped_ivs.ivs.push_back(index);
  }
  return mapped_ivs;
}

// Returns an scf::ParallelOp over the shape of `shaped_value`, handling both
// static and dynamic dimensions.
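//
// E.g. for a hypothetical value %v : memref<4x?xf32> this emits roughly:
//   %d1 = memref.dim %v, %c1 : memref<4x?xf32>
//   scf.parallel (%i, %j) = (%c0, %c0) to (%c4, %d1) step (%c1, %c1)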
scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
                                  OpBuilder* b) {
  Value zero = b->create<ConstantIndexOp>(loc, 0);
  Value one = b->create<ConstantIndexOp>(loc, 1);

  ArrayRef<int64_t> shape =
      shaped_value.getType().cast<ShapedType>().getShape();
  SmallVector<Value, 2> lower, upper, step;
  for (auto dim : llvm::enumerate(shape)) {
    upper.push_back(
        GetStaticOrDynamicDim(loc, shaped_value, dim.index(), dim.value(), b));
    lower.push_back(zero);
    step.push_back(one);
  }
  return b->create<scf::ParallelOp>(loc, lower, upper, step);
}

// Converts `lmhlo.ReduceOp` into two scf::ParallelOp ops and an scf::ReduceOp.
// The outer `ParallelOp` refers to the parallel loops, if there are
// any. The inner `ParallelOp` refers to the reduction loops, and the
// `ReduceOp` contains the reduction operator.
//
// Example:
//
//  "lmhlo.reduce"(%buffer, %init_buf, %result) ( {
//    ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
//      <LHLO ops>
//  } ) {dimensions = dense<[1]> : tensor<1xi64>}
//      : (memref<100x10x5xf32>, memref<f32>, memref<100x5xf32>) -> ()
//
// is roughly converted into:
//
//  %init = load %init_buf[] : memref<f32>
//  scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) {
//    %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) {
//      %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32>
//      scf.reduce(%elem_to_reduce) {
//        ^bb0(%elem: f32, %acc: f32):  // no predecessors
//          %elem_buf = alloc() : memref<f32>
//          store %elem, %elem_buf[] : memref<f32>
//          %acc_buf = alloc() : memref<f32>
//          store %acc, %acc_buf[] : memref<f32>
//          <LHLO_ops>
//          %acc_result = load %acc_buf[] : memref<f32>
//          scf.reduce.return %acc_result : f32
//      } : f32
//      scf.yield
//    } : f32
//    scf.yield
//  }
class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
 public:
  using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::ReduceOp reduce_op, ArrayRef<Value> /*args*/,
      ConversionPatternRewriter& rewriter) const final {
    // TODO(b/183977252): Handle variadic ReduceOp/ReduceWindowOp.
    if (reduce_op.out().size() != 1) return failure();

    scf::ReduceOp scf_reduce_op =
        CreateReduceOpInNestedParallelLoops(reduce_op, &rewriter);
    ConvertToReductionOperator(reduce_op.getLoc(), scf_reduce_op,
                               &reduce_op.body().front(), &rewriter);
    rewriter.replaceOp(reduce_op, llvm::None);
    return success();
  }

 private:
  // Creates nested `scf.parallel` ops with `scf.reduce`. The outer ParallelOp
  // refers to the parallel dimensions of `reduce_op` if any and the inner
  // ParallelOp refers to the reduction dimensions. The scf.reduce op is
  // returned.
  //
  // If the reduction argument is a memref<100x10x5xf32> and the
  // reduction is performed along dimension 1 then this method will generate
  //
  //  %init = load %init_buf[] : memref<f32>
  //  scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) {
  //    %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) {
  //      %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32>
  //      scf.reduce(%elem_to_reduce) {
  //        <THE BLOCK PTR TO BE RETURNED>
  //      } : f32
  //      scf.yield
  //    } : f32
  //    scf.yield
  //  }
  scf::ReduceOp CreateReduceOpInNestedParallelLoops(
      lmhlo::ReduceOp reduce_op, ConversionPatternRewriter* rewriter) const {
    auto loc = reduce_op.getLoc();
    DenseSet<int> reducing_dims;
    for (const auto& rdim : reduce_op.dimensions().getIntValues()) {
      reducing_dims.insert(rdim.getSExtValue());
    }

    Value operand = reduce_op.inputs().front();
    Value out = reduce_op.out().front();
    SmallVector<Value, 2> parallel_lower, parallel_upper, parallel_step;
    SmallVector<Value, 2> reduce_lower, reduce_upper, reduce_step;
    auto operand_shape = operand.getType().cast<MemRefType>().getShape();
    for (auto dim : llvm::enumerate(operand_shape)) {
      const bool is_reducing_dim = reducing_dims.count(dim.index());

      Value ub = GetStaticOrDynamicDim(loc, operand, dim.index(), dim.value(),
                                       rewriter);
      Value lb = rewriter->create<ConstantIndexOp>(loc, 0);
      Value step = rewriter->create<ConstantIndexOp>(loc, 1);
      (is_reducing_dim ? reduce_lower : parallel_lower).push_back(lb);
      (is_reducing_dim ? reduce_upper : parallel_upper).push_back(ub);
      (is_reducing_dim ? reduce_step : parallel_step).push_back(step);
    }
    // Load initial value from memref<element_type>.
    SmallVector<Value, 1> init_value = {rewriter->create<memref::LoadOp>(
        loc, *reduce_op.init_values().begin())};
    // Outer ParallelOp is not needed if it is a reduction across all dims.
    scf::ParallelOp outer;
    if (!parallel_lower.empty()) {
      outer = rewriter->create<scf::ParallelOp>(loc, parallel_lower,
                                                parallel_upper, parallel_step);
      rewriter->setInsertionPointToStart(outer.getBody());
    }
    scf::ParallelOp inner = rewriter->create<scf::ParallelOp>(
        loc, reduce_lower, reduce_upper, reduce_step, ValueRange(init_value));
    Value reduction_result = *inner.getResults().begin();

    SmallVector<Value, 1> out_indices;
    if (outer != nullptr) {
      out_indices.reserve(outer.getNumLoops());
      for (Value iv : outer.getInductionVars()) {
        out_indices.push_back(iv);
      }
    } else {
      out_indices.push_back(rewriter->create<ConstantIndexOp>(loc, 0));
    }

    rewriter->create<memref::StoreOp>(loc, reduction_result, out, out_indices);

    // Load the element to reduce.
    SmallVector<Value, 2> indices;
    indices.reserve(operand_shape.size());

    if (outer) {
      auto inner_ivs_it = inner.getInductionVars().begin();
      auto outer_ivs_it = outer.getInductionVars().begin();
      for (unsigned i = 0, e = operand_shape.size(); i < e; ++i) {
        indices.push_back(reducing_dims.count(i) ? *inner_ivs_it++
                                                 : *outer_ivs_it++);
      }
    } else {
      indices = inner.getInductionVars();
    }

    rewriter->setInsertionPointToStart(inner.getBody());
    Value elem = rewriter->create<mlir::memref::LoadOp>(
        loc, reduce_op.inputs().front(), indices);
    return rewriter->create<scf::ReduceOp>(loc, elem);
  }
};

// Pseudocode:
//  for each index O in output
//    accumulator = neutral_value
//    for each index W in window
//      in_bounds = true
//      for each dimension i from 0 to rank - 1
//        index = O[i] * stride[i] + W[i] - pad_low[i]
//        in_bounds = in_bounds && (index `ult` shape[i])
//        I[i] = index
//      if (in_bounds)
//        value = input[I]
//      else
//        value = neutral_value
//      accumulator = reduction_operator(accumulator, value)
//    output[O] = accumulator
//
// Converts `lmhlo.ReduceWindowOp` into two scf::ParallelOp ops and an
// scf::ReduceOp.
// The outer `ParallelOp` refers to the parallel loops that traverse the output
// buffer. The inner `ParallelOp` refers to the reduction loops that traverse
// the reduction windows, and the `ReduceOp` contains the reduction operator.
//
// Example:
//
// func @reduce_window(%arg: memref<112x112xf32>,
//                     %init: memref<f32>,
//                     %result: memref<56x56xf32>) {
//   "lmhlo.reduce_window"(%arg, %init, %result) ( {
//     ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
//       "lmhlo.maximum"(%lhs, %rhs, %res)
//         : (memref<f32>, memref<f32>, memref<f32>) -> ()
//       "lmhlo.terminator"() : () -> ()
//   }) {
//     padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
//     window_dimensions = dense<[3, 3]> : tensor<2xi64>,
//     window_strides = dense<[2, 2]> : tensor<2xi64>
//   } : (memref<112x112xf32>, memref<f32>, memref<56x56xf32>) -> ()
//   return
// }
//
// is roughly converted into:
//
//  %neutral_elem = load %init_buf[] : memref<f32>
//  scf.parallel (%i, %j) = (%c0, %c0) to (%c56, %c56) step (%c1, %c1) {
//    %result = scf.parallel (%iw, %jw) = (%c0, %c0)
//                to (%c3, %c3) step (%c1, %c1) init (%neutral_elem) -> f32 {
//      %in_bounds = <COMPUTE IF INDEX IS INSIDE THE OPERAND, NOT IN THE PAD>
//      %elem = load %operand[%computed_i, %computed_j]
//      %elem_or_neutral = select %in_bounds, %elem, %neutral_elem : f32
//      scf.reduce(%elem_or_neutral) : f32 {
//        ^bb0(%arg7: f32, %arg8: f32):
//          <LHLO ops>
//      }
//      scf.yield
//    }
//    store %result, %output_buffer[%i, %j] : memref<56x56xf32>
//    scf.yield
//  }
//  return
// }
class ReduceWindowOpConverter
    : public OpConversionPattern<lmhlo::ReduceWindowOp> {
 public:
  using OpConversionPattern<lmhlo::ReduceWindowOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::ReduceWindowOp reduce_window_op, ArrayRef<Value> /*args*/,
      ConversionPatternRewriter& rewriter) const final {
    // TODO(b/183977252): Handle variadic ReduceOp/ReduceWindowOp.
    if (reduce_window_op.out().size() != 1) return failure();

    scf::ParallelOp output_loop, window_loop;
    std::tie(output_loop, window_loop) =
        CreateParallelLoopsToTraverseOutputAndWindow(reduce_window_op,
                                                     &rewriter);

    scf::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops(
        reduce_window_op, output_loop, window_loop, &rewriter);

    ConvertToReductionOperator(reduce_window_op.getLoc(), reduce_op,
                               &reduce_window_op.body().front(), &rewriter);
    rewriter.replaceOp(reduce_window_op, llvm::None);
    return success();
  }

 private:
  std::pair<scf::ParallelOp, scf::ParallelOp>
  CreateParallelLoopsToTraverseOutputAndWindow(
      lmhlo::ReduceWindowOp reduce_window_op,
      ConversionPatternRewriter* rewriter) const {
    auto loc = reduce_window_op.getLoc();
    Value init_value = rewriter->create<memref::LoadOp>(
        loc, reduce_window_op.init_values()[0]);

    Value zero = rewriter->create<ConstantIndexOp>(loc, 0);
    Value one = rewriter->create<ConstantIndexOp>(loc, 1);

    // Create an outer parallel loop that spans the output of ReduceWindowOp.
    Value output = reduce_window_op.out()[0];
    auto output_loop = MakeLoopOverShape(loc, output, rewriter);

    // Create a nested loop that traverses the window.
    SmallVector<Value, 2> window_lower, window_upper, window_step;
    rewriter->setInsertionPointToStart(output_loop.getBody());
    for (const auto& window_dim : reduce_window_op.window_dimensions()) {
      window_step.push_back(one);
      window_lower.push_back(zero);
      window_upper.push_back(
          rewriter->create<ConstantIndexOp>(loc, window_dim.getSExtValue()));
    }
    auto window_loop = rewriter->create<scf::ParallelOp>(
        loc, window_lower, window_upper, window_step, ValueRange(init_value));

    Value reduction_result = *window_loop.getResults().begin();
    auto output_ivs = output_loop.getInductionVars();
    rewriter->create<memref::StoreOp>(loc, reduction_result, output,
                                      output_ivs);
    return std::make_pair(output_loop, window_loop);
  }

  scf::ReduceOp CreateReduceOpInNestedParallelLoops(
      lmhlo::ReduceWindowOp reduce_window_op, scf::ParallelOp output_loop,
      scf::ParallelOp window_loop, ConversionPatternRewriter* rewriter) const {
    rewriter->setInsertionPointToStart(window_loop.getBody());
    auto loc = reduce_window_op.getLoc();

    if (reduce_window_op.base_dilations().hasValue() ||
        reduce_window_op.window_dilations().hasValue()) {
      reduce_window_op.emitRemark(
          "Lowering to parallel loops does not support `base_dilations` or "
          "`window_dilations` attributes yet. The attributes will be ignored.");
    }

    Value input = reduce_window_op.inputs()[0];
    auto input_type = input.getType().cast<MemRefType>();

    // Compute ivs in 'arg' buffer and whether these ivs are in the pad area.
    MappedIvs mapped_ivs = MapWindowIvsToInput(
        reduce_window_op, input, output_loop.getInductionVars(),
        window_loop.getInductionVars(), rewriter);

    auto elem_or_init = rewriter->create<scf::IfOp>(
        loc, input_type.getElementType(), mapped_ivs.in_bounds,
        /*withElseRegion=*/true);

    OpBuilder then_builder =
        elem_or_init.getThenBodyBuilder(rewriter->getListener());
    Value elem =
        then_builder.create<mlir::memref::LoadOp>(loc, input, mapped_ivs.ivs);
    then_builder.create<scf::YieldOp>(loc, elem);

    OpBuilder else_builder =
        elem_or_init.getElseBodyBuilder(rewriter->getListener());
    else_builder.create<scf::YieldOp>(loc, *window_loop.initVals().begin());

    return rewriter->create<scf::ReduceOp>(loc,
                                           *elem_or_init.results().begin());
  }
};

// See the operation semantics in
// https://www.tensorflow.org/xla/operation_semantics#selectandscatter
//
// Pseudocode:
//  scf.parallel(coordinates O in the output):
//    output[O] = init
//  scf.parallel(coordinates S in the source):
//    selected_ivs = 0
//    selected_val = 0
//    initialized_flag = false
//    scf.for (first dim W_1 in the window)
//         iter_args (selected_ivs, selected_val, initialized_flag):
//      ...
//        scf.for (last dim W_N in the window)
//             iter_args (selected_ivs, selected_val, initialized_flag):
//          I = S * stride + W - pad_low
//          if I within bounds of operand:
//            if (initialized_flag):
//              pred = select(selected_val, operand(I))
//              if (pred):
//                selected_val = operand(I)
//                selected_ivs = I
//            else:
//              selected_val = operand(I)
//              selected_ivs = I
//              initialized_flag = true
//    output(selected_ivs) = scatter(output(selected_ivs), source(S))
class SelectAndScatterOpConverter
    : public OpConversionPattern<lmhlo::SelectAndScatterOp> {
 public:
  using OpConversionPattern<lmhlo::SelectAndScatterOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> /*args*/,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = s_and_s_op.getLoc();
    InitializeOutput(s_and_s_op, &rewriter);
    scf::ParallelOp loop_over_src =
        MakeLoopOverShape(loc, s_and_s_op.source(), &rewriter);
    rewriter.setInsertionPointToStart(loop_over_src.getBody());

    // Compute the indices of the selected element in the window.
    auto selected_ivs = SelectIvs(s_and_s_op, loop_over_src, &rewriter);

    // Load the source element at the current loop coordinates.
    auto src_elem = rewriter.create<memref::LoadOp>(
        loc, s_and_s_op.source(), loop_over_src.getInductionVars());

    // Compute `out[selected_ivs] = scatter(out[selected_ivs], src_elem)`.
    auto rmw = rewriter.create<GenericAtomicRMWOp>(loc, s_and_s_op.out(),
                                                   selected_ivs);
    OpBuilder rmw_builder = OpBuilder::atBlockEnd(rmw.getBody());
    auto acc_result =
        ApplySingleResultLhloCode(loc, {src_elem, rmw.getCurrentValue()},
                                  &s_and_s_op.scatter().front(), &rmw_builder);
    rmw_builder.create<AtomicYieldOp>(loc, acc_result);

    rewriter.replaceOp(s_and_s_op, llvm::None);
    return success();
  }

 private:
  void InitializeOutput(lmhlo::SelectAndScatterOp s_and_s_op,
                        OpBuilder* b) const {
    auto loc = s_and_s_op.getLoc();
    Value init_value = b->create<memref::LoadOp>(loc, s_and_s_op.init_value());

    scf::ParallelOp loop_over_output =
        MakeLoopOverShape(loc, s_and_s_op.out(), b);
    OpBuilder::InsertionGuard guard(*b);
    b->setInsertionPointToStart(loop_over_output.getBody());
    b->create<memref::StoreOp>(loc, init_value, s_and_s_op.out(),
                               loop_over_output.getInductionVars());
  }

  struct WindowLoops {
    SmallVector<Value, 2> selected_ivs;
    SmallVector<Value, 2> window_ivs;
    scf::ForOp inner_loop;
  };

  WindowLoops InsertWindowLoops(lmhlo::SelectAndScatterOp s_and_s_op,
                                scf::ParallelOp loop_over_src,
                                OpBuilder* b) const {
    auto loc = s_and_s_op.getLoc();
    Value zero = b->create<ConstantIndexOp>(loc, 0);
    Value one = b->create<ConstantIndexOp>(loc, 1);

    auto element_type =
        s_and_s_op.out().getType().cast<MemRefType>().getElementType();
    auto rank = loop_over_src.getNumLoops();

    // `iter_args` = [iv_1, ..., iv_N, selected_value, is_initialized]
    SmallVector<Value, 4> iter_args(rank, zero);
    iter_args.push_back(b->create<mlir::ConstantOp>(
        loc, element_type, b->getFloatAttr(element_type, 0)));
    iter_args.push_back(b->create<mlir::ConstantOp>(
        loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 0)));

    // Create a nested loop that traverses the window.
    OpBuilder::InsertPoint ip;
    WindowLoops result;
    for (const auto& window_dim :
         s_and_s_op.window_dimensions()->getIntValues()) {
      Value upper = b->create<ConstantIndexOp>(loc, window_dim.getSExtValue());
      result.inner_loop =
          b->create<scf::ForOp>(loc, zero, upper, one, iter_args);
      if (b->getInsertionBlock() == loop_over_src.getBody()) {
        ip = b->saveInsertionPoint();
        result.selected_ivs = result.inner_loop.getResults().take_front(rank);
      } else {
        b->create<scf::YieldOp>(loc, result.inner_loop.getResults());
      }
      b->setInsertionPointToStart(result.inner_loop.getBody());
      iter_args = ValueRange{result.inner_loop.getRegionIterArgs()};
      result.window_ivs.push_back(result.inner_loop.getInductionVar());
    }
    b->restoreInsertionPoint(ip);
    return result;
  }

  // Adapter to store iteration arguments of sequential loops that perform
  // select in a window.
  class IterArgs {
   public:
    explicit IterArgs(ValueRange ivs_val_flag) : ivs_val_flag_(ivs_val_flag) {}
    IterArgs(ValueRange ivs, Value value, Value flag) {
      ivs_val_flag_ = ivs;
      ivs_val_flag_.push_back(value);
      ivs_val_flag_.push_back(flag);
    }

    ArrayRef<Value> to_vector() const { return ivs_val_flag_; }

    // Indices of the currently selected value.
    ArrayRef<Value> ivs() const { return to_vector().drop_back(2); }
    // Currently selected value w.r.t. the select() function.
    Value value() const { return ivs_val_flag_.end()[-2]; }
    // i1 flag whether value() and ivs() were initialized.
    Value is_init() const { return ivs_val_flag_.back(); }

   private:
    // Vector that stores iv_1, ..., iv_N, value, init.
    SmallVector<Value, 4> ivs_val_flag_;
  };
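  // Illustrative use (mirrors SelectIvs below):
  //   IterArgs args(window_loops.inner_loop.getRegionIterArgs());
  //   args.ivs()     -> indices of the currently selected element
  //   args.value()   -> currently selected value
  //   args.is_init() -> i1 flag, true once a value has been selected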

  SmallVector<Value, 2> SelectIvs(lmhlo::SelectAndScatterOp s_and_s_op,
                                  scf::ParallelOp loop_over_src,
                                  OpBuilder* b) const {
    auto loc = s_and_s_op.getLoc();

    WindowLoops window_loops = InsertWindowLoops(s_and_s_op, loop_over_src, b);
    auto inner_loop_b =
        OpBuilder::atBlockEnd(window_loops.inner_loop.getBody());

    // Compute ivs in 'arg' buffer and whether these ivs are in the pad area.
    MappedIvs mapped_ivs = MapWindowIvsToInput(
        s_and_s_op, s_and_s_op.operand(), loop_over_src.getInductionVars(),
        window_loops.window_ivs, &inner_loop_b);

    IterArgs ivs_val_flag(window_loops.inner_loop.getRegionIterArgs());

    auto if_in_bounds = inner_loop_b.create<scf::IfOp>(
        loc, window_loops.inner_loop.getResultTypes(), mapped_ivs.in_bounds,
        /*withElseRegion=*/true);

    // Case when we are inside boundaries of 'arg' and not in the pad area.
    {
      OpBuilder in_bounds_then_b =
          if_in_bounds.getThenBodyBuilder(b->getListener());
      auto select_or_init_results = SelectOrInitialize(
          s_and_s_op, mapped_ivs.ivs, &ivs_val_flag, &in_bounds_then_b);
      in_bounds_then_b.create<scf::YieldOp>(loc, select_or_init_results);
    }

    // Case when we are in the pad.
    {
      OpBuilder in_bounds_else_b =
          if_in_bounds.getElseBodyBuilder(b->getListener());
      in_bounds_else_b.create<scf::YieldOp>(loc, ivs_val_flag.to_vector());
    }

    inner_loop_b.create<scf::YieldOp>(loc, if_in_bounds.getResults());
    return window_loops.selected_ivs;
  }

  SmallVector<Value, 4> SelectOrInitialize(lmhlo::SelectAndScatterOp s_and_s_op,
                                           ArrayRef<Value> operand_ivs,
                                           IterArgs* ivs_val_flag,
                                           OpBuilder* b) const {
    auto loc = s_and_s_op.getLoc();
    Value true_i1 = b->create<mlir::ConstantOp>(
        loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));

    TypeRange iter_arg_types{ivs_val_flag->to_vector()};
    Value operand_elem =
        b->create<memref::LoadOp>(loc, s_and_s_op.operand(), operand_ivs);
    auto if_init =
        b->create<scf::IfOp>(loc, iter_arg_types, ivs_val_flag->is_init(),
                             /*withElseRegion=*/true);
    // Init == true, i.e. the iter args are already initialized with a selected
    // element within the boundaries of the operand. The `select` function has
    // to be computed here.
    {
      OpBuilder if_init_then_b = if_init.getThenBodyBuilder(b->getListener());

      auto& lhlo_select = s_and_s_op.select().front();
      Value pred =
          ApplySingleResultLhloCode(loc, {operand_elem, ivs_val_flag->value()},
                                    &lhlo_select, &if_init_then_b);

      auto if_pred = if_init_then_b.create<scf::IfOp>(loc, iter_arg_types, pred,
                                                      /*withElseRegion=*/true);

      // Pred == true, therefore pack the newly selected ivs, value and init
      // flag back into iter_args and return.
      {
        OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder(b->getListener());
        if_pred_then_b.create<scf::YieldOp>(
            loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
      }

      // Pred == false, therefore return the old iter_args.
      {
        OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder(b->getListener());
        if_pred_else_b.create<scf::YieldOp>(loc, ivs_val_flag->to_vector());
      }

      if_init_then_b.create<scf::YieldOp>(loc, if_pred.getResults());
    }
    // Init == false, i.e. only the pad was visited before and this is the
    // first element within the boundaries of the operand.
    {
      OpBuilder if_init_else_b = if_init.getElseBodyBuilder(b->getListener());

      if_init_else_b.create<scf::YieldOp>(
          loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
    }
    return if_init.getResults();
  }
};

struct LhloLegalizeToParallelLoopsPass
    : public LhloLegalizeToParallelLoopsPassBase<
          LhloLegalizeToParallelLoopsPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry
        .insert<StandardOpsDialect, memref::MemRefDialect, scf::SCFDialect>();
  }

  void runOnFunction() override {
    auto func = getFunction();

    OwningRewritePatternList patterns(&getContext());
    // clang-format off
    patterns.insert<
        ReduceOpConverter,
        ReduceWindowOpConverter,
        SelectAndScatterOpConverter
    >(func.getContext());
    // clang-format on

    ConversionTarget target(getContext());
    target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
                           StandardOpsDialect, scf::SCFDialect, LmhloDialect>();
    target.addIllegalOp<lmhlo::ReduceOp, lmhlo::ReduceWindowOp,
                        lmhlo::SelectAndScatterOp>();

    if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};
}  // namespace

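// A minimal usage sketch, assuming a standard mlir::PassManager setup (not
// part of this file):
//   PassManager pm(&context);
//   pm.addNestedPass<FuncOp>(createLegalizeLhloToParallelLoopsPass());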
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass() {
  return std::make_unique<LhloLegalizeToParallelLoopsPass>();
}

}  // namespace lmhlo
}  // namespace mlir