/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace lmhlo {
namespace {

// Clones and adapts the code in `lhlo_block`, which works on buffers and has
// a single output buffer, to make it compatible with `operands` that have the
// element types of the respective buffers. Returns the computed value.
//
// Example: for `operands` with (f32, i32) types and a block with LHLO ops and
// with signature
//   ^bb(%lhs: memref<f32>, %rhs: memref<i32>, %res: memref<i1>):
//     <LHLO_ops>
// this inserts the necessary alloc and store ops to compute and return a
// result of `i1` type.
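// A rough sketch of the emitted IR for the example above (value names are
// illustrative, not produced verbatim):
//   %lhs_buf = memref.alloc() : memref<f32>
//   %rhs_buf = memref.alloc() : memref<i32>
//   %res_buf = memref.alloc() : memref<i1>
//   memref.store %lhs, %lhs_buf[] : memref<f32>
//   memref.store %rhs, %rhs_buf[] : memref<i32>
//   <cloned LHLO ops operating on the buffers>
//   %result = memref.load %res_buf[] : memref<i1>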
Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
                                Block* lhlo_block, OpBuilder* b) {
  SmallVector<Value, 2> arg_bufs;
  for (auto arg_type : lhlo_block->getArgumentTypes()) {
    arg_bufs.push_back(
        b->create<memref::AllocOp>(loc, arg_type.cast<MemRefType>()));
  }
  for (auto operand : llvm::enumerate(operands)) {
    b->create<memref::StoreOp>(loc, operand.value(), arg_bufs[operand.index()]);
  }
  // Clone the ops from `lhlo_block`.
  BlockAndValueMapping mapping;
  mapping.map(lhlo_block->getArguments(), arg_bufs);
  for (auto& nested : lhlo_block->without_terminator()) {
    auto clone = b->clone(nested, mapping);
    mapping.map(nested.getResults(), clone->getResults());
  }
  return b->create<memref::LoadOp>(loc, arg_bufs.back());
}

// Converts a block with LHLO ops and with signature
//   ^bb(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
// into the reduction operator of an scf.reduce op by allocating buffers for
// the scalar arguments and for the result of `scf.reduce`, which makes the
// block compatible with the LHLO ops.
void ConvertToReductionOperator(Location loc, scf::ReduceOp reduce_op,
                                Block* lhlo_block, OpBuilder* b) {
  Block& loop_reduce_op_body = reduce_op.reductionOperator().front();
  OpBuilder::InsertionGuard guard(*b);
  b->setInsertionPointToStart(&loop_reduce_op_body);
  b->create<scf::ReduceReturnOp>(
      loc, ApplySingleResultLhloCode(loc, loop_reduce_op_body.getArguments(),
                                     lhlo_block, b));
}

// Returns the result of a ConstantIndexOp if `dim` is static; otherwise emits
// a memref::DimOp to extract the dimension at runtime.
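// For example (a sketch; %buf stands for the shaped value):
//   static dim 10  ->  %c10 = constant 10 : index
//   dynamic dim    ->  %d = memref.dim %buf, %c0 : memref<?x10xf32>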
Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value,
                            size_t dim_index, int64_t dim, OpBuilder* b) {
  return dim == ShapedType::kDynamicSize
             ? b->create<memref::DimOp>(loc, shaped_value, dim_index)
                   .getResult()
             : b->create<ConstantIndexOp>(loc, dim);
}

struct MappedIvs {
  // False if the mapped indices are in the padding area, true otherwise.
  Value in_bounds;
  // Mapped indices.
  SmallVector<Value, 2> ivs;
};

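// Maps the output induction variables `ivs` and the window induction
// variables `window_ivs` of `op` to an index into `operand`:
//   index[i] = ivs[i] * window_strides[i] + window_ivs[i] - pad_low[i]
// where pad_low[i] is padding[{i, 0}], and records in `in_bounds` whether the
// mapped index lies inside the operand, i.e. not in the padding area.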
template <typename OpTy>
MappedIvs MapWindowIvsToInput(OpTy op, Value operand, ValueRange ivs,
                              ValueRange window_ivs, OpBuilder* b) {
  MappedIvs mapped_ivs;

  if (!op.window_strides().hasValue()) {
    op.emitOpError("No window strides specified.");
  }
  auto window_strides = op.window_strides().getValue();

  if (!op.padding().hasValue()) {
    op.emitOpError("No padding specified.");
  }
  auto padding = op.padding().getValue();

  auto loc = op.getLoc();
  auto operand_shape = operand.getType().template cast<MemRefType>().getShape();

  // `in_bounds` is false when the mapped indices are in the padding area.
  mapped_ivs.in_bounds = b->create<mlir::ConstantOp>(
      loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));
  for (unsigned i = 0, e = ivs.size(); i < e; ++i) {
    auto stride = window_strides.template getValue<llvm::APInt>(i);
    auto pad_low = padding.template getValue<llvm::APInt>({i, 0});

    Value stride_val = b->create<ConstantIndexOp>(loc, stride.getSExtValue());
    Value pad_low_val = b->create<ConstantIndexOp>(loc, pad_low.getSExtValue());

    Value center = b->create<MulIOp>(loc, ivs[i], stride_val);
    Value offset = b->create<SubIOp>(loc, window_ivs[i], pad_low_val);
    Value index = b->create<AddIOp>(loc, center, offset);
    Value upper_bound =
        GetStaticOrDynamicDim(loc, operand, i, operand_shape[i], b);
    // We must check whether 0 <= index_i < shape_i, as otherwise we are in
    // the pad and then we have to use the neutral element for reduction.
    // Equivalently, it can be computed as the unsigned comparison index_i <
    // shape_i, since a negative value wraps to a large positive value.
    mapped_ivs.in_bounds = b->create<mlir::AndOp>(
        loc, mapped_ivs.in_bounds,
        b->create<CmpIOp>(loc, CmpIPredicate::ult, index, upper_bound));
    mapped_ivs.ivs.push_back(index);
  }
  return mapped_ivs;
}

// Returns an scf::ParallelOp over a shaped value with static or dynamic shape.
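// For a value of type memref<4x?xf32> this emits, roughly:
//   scf.parallel (%i, %j) = (%c0, %c0) to (%c4, %dim1) step (%c1, %c1)
// where %dim1 is the runtime size of the dynamic dimension.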
scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
                                  OpBuilder* b) {
  Value zero = b->create<ConstantIndexOp>(loc, 0);
  Value one = b->create<ConstantIndexOp>(loc, 1);

  ArrayRef<int64_t> shape =
      shaped_value.getType().cast<ShapedType>().getShape();
  SmallVector<Value, 2> lower, upper, step;
  for (auto dim : llvm::enumerate(shape)) {
    upper.push_back(
        GetStaticOrDynamicDim(loc, shaped_value, dim.index(), dim.value(), b));
    lower.push_back(zero);
    step.push_back(one);
  }
  return b->create<scf::ParallelOp>(loc, lower, upper, step);
}

// Converts `lmhlo.ReduceOp` into two scf::ParallelOps and an scf::ReduceOp.
// The outer `ParallelOp` refers to the parallel loops if there are
// any. The inner `ParallelOp` refers to the reduction loops and the `ReduceOp`
// contains the reduction operator.
//
// Example:
//
//  "lmhlo.reduce"(%buffer, %init_buf, %result) ( {
//    ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
//      <LHLO ops>
//    } ) {dimensions = dense<[1]> : tensor<1xi64>}
//      : (memref<100x10x5xf32>, memref<f32>, memref<100x5xf32>) -> ()
//
//  is roughly converted into:
//
//  %init = load %init_buf[] : memref<f32>
//  scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) {
//    %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) {
//      %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32>
//      scf.reduce(%elem_to_reduce)  {
//        ^bb0(%elem: f32, %acc: f32):   // no predecessors
//          %elem_buf = alloc() : memref<f32>
//          store %elem, %elem_buf[] : memref<f32>
//          %acc_buf = alloc() : memref<f32>
//          store %acc, %acc_buf[] : memref<f32>
//          <LHLO_ops>
//          %acc_result = load %acc_buf[] : memref<f32>
//          scf.reduce.return %acc_result : f32
//      } : f32
//      scf.yield
//    } : f32
//    scf.yield
//  }
class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
 public:
  using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::ReduceOp reduce_op, ArrayRef<Value> /*args*/,
      ConversionPatternRewriter& rewriter) const final {
    // TODO(b/183977252) : Handle variadic ReduceOp/ReduceWindowOp
    if (reduce_op.out().size() != 1) return failure();

    scf::ReduceOp scf_reduce_op =
        CreateReduceOpInNestedParallelLoops(reduce_op, &rewriter);
    ConvertToReductionOperator(reduce_op.getLoc(), scf_reduce_op,
                               &reduce_op.body().front(), &rewriter);
    rewriter.replaceOp(reduce_op, llvm::None);
    return success();
  }

 private:
  // Creates nested `scf.parallel` ops with `scf.reduce`. The outer ParallelOp
  // refers to the parallel dimensions of `reduce_op` if any and the inner
  // ParallelOp refers to the reduction dimensions. The scf.reduce op is
  // returned.
  //
  // If the reduction argument is a memref<100x10x5xf32> and the
  // reduction is performed along dimension 1 then this method will generate
  //
  //  %init = load %init_buf[] : memref<f32>
  //  scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) {
  //    %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) {
  //      %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32>
  //      scf.reduce(%elem_to_reduce)  {
  //        <THE BLOCK PTR TO BE RETURNED>
  //      } : f32
  //      scf.yield
  //    } : f32
  //    scf.yield
  //  }
  scf::ReduceOp CreateReduceOpInNestedParallelLoops(
      lmhlo::ReduceOp reduce_op, ConversionPatternRewriter* rewriter) const {
    auto loc = reduce_op.getLoc();
    DenseSet<int> reducing_dims;
    for (const auto& rdim : reduce_op.dimensions().getIntValues()) {
      reducing_dims.insert(rdim.getSExtValue());
    }

    Value operand = reduce_op.inputs().front();
    Value out = reduce_op.out().front();
    SmallVector<Value, 2> parallel_lower, parallel_upper, parallel_step;
    SmallVector<Value, 2> reduce_lower, reduce_upper, reduce_step;
    auto operand_shape = operand.getType().cast<MemRefType>().getShape();
    for (auto dim : llvm::enumerate(operand_shape)) {
      const bool is_reducing_dim = reducing_dims.count(dim.index());

      Value ub = GetStaticOrDynamicDim(loc, operand, dim.index(), dim.value(),
                                       rewriter);
      Value lb = rewriter->create<ConstantIndexOp>(loc, 0);
      Value step = rewriter->create<ConstantIndexOp>(loc, 1);
      (is_reducing_dim ? reduce_lower : parallel_lower).push_back(lb);
      (is_reducing_dim ? reduce_upper : parallel_upper).push_back(ub);
      (is_reducing_dim ? reduce_step : parallel_step).push_back(step);
    }
    // Load initial value from memref<element_type>.
    SmallVector<Value, 1> init_value = {rewriter->create<memref::LoadOp>(
        loc, *reduce_op.init_values().begin())};
    // Outer ParallelOp is not needed if it is a reduction across all dims.
    scf::ParallelOp outer;
    if (!parallel_lower.empty()) {
      outer = rewriter->create<scf::ParallelOp>(loc, parallel_lower,
                                                parallel_upper, parallel_step);
      rewriter->setInsertionPointToStart(outer.getBody());
    }
    scf::ParallelOp inner = rewriter->create<scf::ParallelOp>(
        loc, reduce_lower, reduce_upper, reduce_step, ValueRange(init_value));
    Value reduction_result = *inner.getResults().begin();

    SmallVector<Value, 1> out_indices;
    if (outer != nullptr) {
      out_indices.reserve(outer.getNumLoops());
      for (Value iv : outer.getInductionVars()) {
        out_indices.push_back(iv);
      }
    } else {
      out_indices.push_back(rewriter->create<ConstantIndexOp>(loc, 0));
    }

    rewriter->create<memref::StoreOp>(loc, reduction_result, out, out_indices);

    // Load the element to reduce.
    SmallVector<Value, 2> indices;
    indices.reserve(operand_shape.size());

    if (outer) {
      auto inner_ivs_it = inner.getInductionVars().begin();
      auto outer_ivs_it = outer.getInductionVars().begin();
      for (unsigned i = 0, e = operand_shape.size(); i < e; ++i) {
        indices.push_back(reducing_dims.count(i) ? *inner_ivs_it++
                                                 : *outer_ivs_it++);
      }
    } else {
      indices = inner.getInductionVars();
    }

    rewriter->setInsertionPointToStart(inner.getBody());
    Value elem = rewriter->create<mlir::memref::LoadOp>(
        loc, reduce_op.inputs().front(), indices);
    return rewriter->create<scf::ReduceOp>(loc, elem);
  }
};

// Pseudocode:
// for each index O in output
//   accumulator = neutral_value
//   in_bounds = true
//   for each index W in window
//     for each dimension i from 0 to rank - 1
//       index = O[i] * stride[i] + W[i] - pad_low[i]
//       in_bounds = in_bounds && (index `ult` shape[i])
//       I[i] = index
//     if (in_bounds)
//       value = input[I]
//     else
//       value = neutral_value
//     accumulator = reduction_operator(accumulator, value)
//   output[O] = accumulator
//
// Converts `lmhlo.ReduceWindowOp` into two scf::ParallelOps and an
// scf::ReduceOp.
// The outer `ParallelOp` refers to the parallel loops that traverse the output
// buffer. The inner `ParallelOp` refers to the reduction loops that traverse
// the reduction windows and the `ReduceOp` contains the reduction operator.
//
// Example:
//
// func @reduce_window(%arg: memref<112x112xf32>,
//              %init: memref<f32>,
//              %result: memref<56x56xf32>) {
//   "lmhlo.reduce_window"(%arg, %init, %result) ( {
//     ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
//       "lmhlo.maximum"(%lhs, %rhs, %res)
//         : (memref<f32>, memref<f32>, memref<f32>) -> ()
//       "lmhlo.terminator"() : () -> ()
//     }) {
//       padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
//       window_dimensions = dense<[3, 3]> : tensor<2xi64>,
//       window_strides = dense<[2, 2]> : tensor<2xi64>
//     } : (memref<112x112xf32>, memref<f32>, memref<56x56xf32>) -> ()
//   return
// }
//
// is roughly converted into:
//
//    %neutral_elem = load %init_buf[] : memref<f32>
//    scf.parallel (%i, %j) = (%c0, %c0) to (%c56, %c56) step (%c1, %c1) {
//      %result = scf.parallel (%iw, %jw) = (%c0, %c0)
//                  to (%c3, %c3) step (%c1, %c1) neutral_elem (%0) -> f32 {
//        %in_bounds = <COMPUTE IF INDEX IS IN OPERAND'S pad>
//        %elem = load %operand[%computed_i, %computed_j]
//        %elem_or_neutral = select %in_bounds, %elem, %neutral_elem : f32
//        scf.reduce(%elem_or_neutral)  : f32 {
//          ^bb0(%arg7: f32, %arg8: f32):
//            <LHLO ops>
//        }
//        scf.yield
//      }
//      store %result, %output_buffer[%i, %j] : memref<56x56xf32>
//      scf.yield
//    }
//    return
//  }
class ReduceWindowOpConverter
    : public OpConversionPattern<lmhlo::ReduceWindowOp> {
 public:
  using OpConversionPattern<lmhlo::ReduceWindowOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::ReduceWindowOp reduce_window_op, ArrayRef<Value> /*args*/,
      ConversionPatternRewriter& rewriter) const final {
    // TODO(b/183977252) : Handle variadic ReduceOp/ReduceWindowOp
    if (reduce_window_op.out().size() != 1) return failure();

    scf::ParallelOp output_loop, window_loop;
    std::tie(output_loop, window_loop) =
        CreateParallelLoopsToTraverseOutputAndWindow(reduce_window_op,
                                                     &rewriter);

    scf::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops(
        reduce_window_op, output_loop, window_loop, &rewriter);

    ConvertToReductionOperator(reduce_window_op.getLoc(), reduce_op,
                               &reduce_window_op.body().front(), &rewriter);
    rewriter.replaceOp(reduce_window_op, llvm::None);
    return success();
  }

 private:
  std::pair<scf::ParallelOp, scf::ParallelOp>
  CreateParallelLoopsToTraverseOutputAndWindow(
      lmhlo::ReduceWindowOp reduce_window_op,
      ConversionPatternRewriter* rewriter) const {
    auto loc = reduce_window_op.getLoc();
    Value init_value = rewriter->create<memref::LoadOp>(
        loc, reduce_window_op.init_values()[0]);

    Value zero = rewriter->create<ConstantIndexOp>(loc, 0);
    Value one = rewriter->create<ConstantIndexOp>(loc, 1);

    // Create an outer parallel loop that spans the output of ReduceWindowOp.
    Value output = reduce_window_op.out()[0];
    auto output_loop = MakeLoopOverShape(loc, output, rewriter);

    // Create a nested loop that traverses the window.
    SmallVector<Value, 2> window_lower, window_upper, window_step;
    rewriter->setInsertionPointToStart(output_loop.getBody());
    for (const auto& window_dim : reduce_window_op.window_dimensions()) {
      window_step.push_back(one);
      window_lower.push_back(zero);
      window_upper.push_back(
          rewriter->create<ConstantIndexOp>(loc, window_dim.getSExtValue()));
    }
    auto window_loop = rewriter->create<scf::ParallelOp>(
        loc, window_lower, window_upper, window_step, ValueRange(init_value));

    Value reduction_result = *window_loop.getResults().begin();
    auto output_ivs = output_loop.getInductionVars();
    rewriter->create<memref::StoreOp>(loc, reduction_result, output,
                                      output_ivs);
    return std::make_pair(output_loop, window_loop);
  }

  scf::ReduceOp CreateReduceOpInNestedParallelLoops(
      lmhlo::ReduceWindowOp reduce_window_op, scf::ParallelOp output_loop,
      scf::ParallelOp window_loop, ConversionPatternRewriter* rewriter) const {
    rewriter->setInsertionPointToStart(window_loop.getBody());
    auto loc = reduce_window_op.getLoc();

    if (reduce_window_op.base_dilations().hasValue() ||
        reduce_window_op.window_dilations().hasValue()) {
      reduce_window_op.emitRemark(
          "Lowering to parallel loops does not support `base_dilations` or "
          "`window_dilations` attributes yet. The attributes will be ignored.");
    }

    Value input = reduce_window_op.inputs()[0];
    auto input_type = input.getType().cast<MemRefType>();

    // Compute the ivs in the 'arg' buffer and whether these ivs are in the
    // pad area or not.
    MappedIvs mapped_ivs = MapWindowIvsToInput(
        reduce_window_op, input, output_loop.getInductionVars(),
        window_loop.getInductionVars(), rewriter);

    auto elem_or_init = rewriter->create<scf::IfOp>(
        loc, input_type.getElementType(), mapped_ivs.in_bounds,
        /*withElseRegion=*/true);

    OpBuilder then_builder =
        elem_or_init.getThenBodyBuilder(rewriter->getListener());
    Value elem =
        then_builder.create<mlir::memref::LoadOp>(loc, input, mapped_ivs.ivs);
    then_builder.create<scf::YieldOp>(loc, elem);

    OpBuilder else_builder =
        elem_or_init.getElseBodyBuilder(rewriter->getListener());
    else_builder.create<scf::YieldOp>(loc, *window_loop.initVals().begin());

    return rewriter->create<scf::ReduceOp>(loc,
                                           *elem_or_init.results().begin());
  }
};

// See the operation semantics in
// https://www.tensorflow.org/xla/operation_semantics#selectandscatter
//
// Pseudocode:
//  scf.parallel(coordinates O in the output):
//    output[O] = init
//  scf.parallel(coordinates S in the source):
//    selected_ivs = 0
//    selected_val = 0
//    initialized_flag = false
//    scf.for (first dim W_1 in the window)
//         iter_args (selected_ivs, selected_val, initialized_flag):
//    ...
//      scf.for (last dim W_N in the window):
//           iter_args (selected_ivs, selected_val, initialized_flag):
//        I = S * stride + W - pad_low
//        if I within bounds of operand:
//          if (initialized_flag):
//            pred = select(selected_value, operand(I))
//            if (pred)
//              selected_value = operand(I)
//              selected_index = I
//          else
//            selected_value = operand(I)
//            selected_index = I
//            initialized_flag = true
//    output(selected_index) = scatter(output(selected_index), source(S))
class SelectAndScatterOpConverter
    : public OpConversionPattern<lmhlo::SelectAndScatterOp> {
 public:
  using OpConversionPattern<lmhlo::SelectAndScatterOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> /*args*/,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = s_and_s_op.getLoc();
    InitializeOutput(s_and_s_op, &rewriter);
    scf::ParallelOp loop_over_src =
        MakeLoopOverShape(loc, s_and_s_op.source(), &rewriter);
    rewriter.setInsertionPointToStart(loop_over_src.getBody());

    // Compute indices of the selected element in the window.
    auto selected_ivs = SelectIvs(s_and_s_op, loop_over_src, &rewriter);

    // Load `source[S]` for the current source coordinates S.
    auto src_elem = rewriter.create<memref::LoadOp>(
        loc, s_and_s_op.source(), loop_over_src.getInductionVars());

    // Compute `out[selected_ivs] = scatter(out[selected_ivs], src_elem)`.
    auto rmw = rewriter.create<GenericAtomicRMWOp>(loc, s_and_s_op.out(),
                                                   selected_ivs);
    OpBuilder rmw_builder = OpBuilder::atBlockEnd(rmw.getBody());
    auto acc_result =
        ApplySingleResultLhloCode(loc, {src_elem, rmw.getCurrentValue()},
                                  &s_and_s_op.scatter().front(), &rmw_builder);
    rmw_builder.create<AtomicYieldOp>(loc, acc_result);

    rewriter.replaceOp(s_and_s_op, llvm::None);
    return success();
  }

 private:
  void InitializeOutput(lmhlo::SelectAndScatterOp s_and_s_op,
                        OpBuilder* b) const {
    auto loc = s_and_s_op.getLoc();
    Value init_value = b->create<memref::LoadOp>(loc, s_and_s_op.init_value());

    scf::ParallelOp loop_over_output =
        MakeLoopOverShape(loc, s_and_s_op.out(), b);
    OpBuilder::InsertionGuard guard(*b);
    b->setInsertionPointToStart(loop_over_output.getBody());
    b->create<memref::StoreOp>(loc, init_value, s_and_s_op.out(),
                               loop_over_output.getInductionVars());
  }

  struct WindowLoops {
    SmallVector<Value, 2> selected_ivs;
    SmallVector<Value, 2> window_ivs;
    scf::ForOp inner_loop;
  };
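  // Inserts a perfectly nested sequence of scf.for loops over the window
  // dimensions, threading [iv_1, ..., iv_N, selected_value, initialized_flag]
  // through the loops' iter_args. A rough sketch for a 1-d window of size 3
  // (value names are illustrative, not produced verbatim):
  //   %res:3 = scf.for %w = %c0 to %c3 step %c1
  //       iter_args(%sel_iv = %c0, %sel_val = %cst, %init = %false) {
  //     <body filled in by the caller; yields the updated iter_args>
  //   }
  // `selected_ivs` then holds the first `rank` results of the outermost loop.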
  WindowLoops InsertWindowLoops(lmhlo::SelectAndScatterOp s_and_s_op,
                                scf::ParallelOp loop_over_src,
                                OpBuilder* b) const {
    auto loc = s_and_s_op.getLoc();
    Value zero = b->create<ConstantIndexOp>(loc, 0);
    Value one = b->create<ConstantIndexOp>(loc, 1);

    auto element_type =
        s_and_s_op.out().getType().cast<MemRefType>().getElementType();
    auto rank = loop_over_src.getNumLoops();

    // `iter_args` = [iv_1, ..., iv_N, selected_value, is_initialized]
    SmallVector<Value, 4> iter_args(rank, zero);
    iter_args.push_back(b->create<mlir::ConstantOp>(
        loc, element_type, b->getFloatAttr(element_type, 0)));
    iter_args.push_back(b->create<mlir::ConstantOp>(
        loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 0)));

    // Create a nested loop that traverses the window.
    OpBuilder::InsertPoint ip;
    WindowLoops result;
    for (const auto& window_dim :
         s_and_s_op.window_dimensions()->getIntValues()) {
      Value upper = b->create<ConstantIndexOp>(loc, window_dim.getSExtValue());
      result.inner_loop =
          b->create<scf::ForOp>(loc, zero, upper, one, iter_args);
      if (b->getInsertionBlock() == loop_over_src.getBody()) {
        ip = b->saveInsertionPoint();
        result.selected_ivs = result.inner_loop.getResults().take_front(rank);
      } else {
        b->create<scf::YieldOp>(loc, result.inner_loop.getResults());
      }
      b->setInsertionPointToStart(result.inner_loop.getBody());
      iter_args = ValueRange{result.inner_loop.getRegionIterArgs()};
      result.window_ivs.push_back(result.inner_loop.getInductionVar());
    }
    b->restoreInsertionPoint(ip);
    return result;
  }

  // Adapter to store iteration arguments of sequential loops that perform
  // select in a window.
  class IterArgs {
   public:
    explicit IterArgs(ValueRange ivs_val_flag) : ivs_val_flag_(ivs_val_flag) {}
    IterArgs(ValueRange ivs, Value value, Value flag) {
      ivs_val_flag_ = ivs;
      ivs_val_flag_.push_back(value);
      ivs_val_flag_.push_back(flag);
    }

    ArrayRef<Value> to_vector() const { return ivs_val_flag_; }

    // Indices of the currently selected value.
    ArrayRef<Value> ivs() const { return to_vector().drop_back(2); }
    // Currently selected value w.r.t. the select() function.
    Value value() const { return ivs_val_flag_.end()[-2]; }
    // i1 flag indicating whether value() and ivs() were initialized.
    Value is_init() const { return ivs_val_flag_.back(); }

   private:
    // Vector that stores iv_1, ..., iv_N, value, init.
    SmallVector<Value, 4> ivs_val_flag_;
  };

  SmallVector<Value, 2> SelectIvs(lmhlo::SelectAndScatterOp s_and_s_op,
                                  scf::ParallelOp loop_over_src,
                                  OpBuilder* b) const {
    auto loc = s_and_s_op.getLoc();

    WindowLoops window_loops = InsertWindowLoops(s_and_s_op, loop_over_src, b);
    auto inner_loop_b =
        OpBuilder::atBlockEnd(window_loops.inner_loop.getBody());

    // Compute the ivs in the 'arg' buffer and whether these ivs are in the
    // pad area.
    MappedIvs mapped_ivs = MapWindowIvsToInput(
        s_and_s_op, s_and_s_op.operand(), loop_over_src.getInductionVars(),
        window_loops.window_ivs, &inner_loop_b);

    IterArgs ivs_val_flag(window_loops.inner_loop.getRegionIterArgs());

    auto if_in_bounds = inner_loop_b.create<scf::IfOp>(
        loc, window_loops.inner_loop.getResultTypes(), mapped_ivs.in_bounds,
        /*withElseRegion=*/true);

    // Case when we are inside the boundaries of 'arg' and not in the pad area.
    {
      OpBuilder in_bounds_then_b =
          if_in_bounds.getThenBodyBuilder(b->getListener());
      auto select_or_init_results = SelectOrInitialize(
          s_and_s_op, mapped_ivs.ivs, &ivs_val_flag, &in_bounds_then_b);
      in_bounds_then_b.create<scf::YieldOp>(loc, select_or_init_results);
    }

    // Case when we are in the pad.
    {
      OpBuilder in_bounds_else_b =
          if_in_bounds.getElseBodyBuilder(b->getListener());
      in_bounds_else_b.create<scf::YieldOp>(loc, ivs_val_flag.to_vector());
    }

    inner_loop_b.create<scf::YieldOp>(loc, if_in_bounds.getResults());
    return window_loops.selected_ivs;
  }

  SmallVector<Value, 4> SelectOrInitialize(lmhlo::SelectAndScatterOp s_and_s_op,
                                           ArrayRef<Value> operand_ivs,
                                           IterArgs* ivs_val_flag,
                                           OpBuilder* b) const {
    auto loc = s_and_s_op.getLoc();
    Value true_i1 = b->create<mlir::ConstantOp>(
        loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));

    TypeRange iter_arg_types{ivs_val_flag->to_vector()};
    Value operand_elem =
        b->create<memref::LoadOp>(loc, s_and_s_op.operand(), operand_ivs);
    auto if_init =
        b->create<scf::IfOp>(loc, iter_arg_types, ivs_val_flag->is_init(),
                             /*withElseRegion=*/true);
    // Init == true, i.e. the iter args were already initialized with a
    // selected element within the boundaries of the operand. The select
    // function has to be computed here.
    {
      OpBuilder if_init_then_b = if_init.getThenBodyBuilder(b->getListener());

      auto& lhlo_select = s_and_s_op.select().front();
      Value pred =
          ApplySingleResultLhloCode(loc, {operand_elem, ivs_val_flag->value()},
                                    &lhlo_select, &if_init_then_b);

      auto if_pred = if_init_then_b.create<scf::IfOp>(loc, iter_arg_types, pred,
                                                      /*withElseRegion=*/true);

      // Pred == true, therefore pack the newly selected ivs, value and init
      // flag back into iter_args and return.
      {
        OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder(b->getListener());
        if_pred_then_b.create<scf::YieldOp>(
            loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
      }

      // Pred == false, therefore return the old iter_args.
      {
        OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder(b->getListener());
        if_pred_else_b.create<scf::YieldOp>(loc, ivs_val_flag->to_vector());
      }

      if_init_then_b.create<scf::YieldOp>(loc, if_pred.getResults());
    }
    // Init == false, i.e. only the pad was visited before and this is the
    // first element within the boundaries of the operand.
    {
      OpBuilder if_init_else_b = if_init.getElseBodyBuilder(b->getListener());

      if_init_else_b.create<scf::YieldOp>(
          loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
    }
    return if_init.getResults();
  }
};

struct LhloLegalizeToParallelLoopsPass
    : public LhloLegalizeToParallelLoopsPassBase<
          LhloLegalizeToParallelLoopsPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry
        .insert<StandardOpsDialect, memref::MemRefDialect, scf::SCFDialect>();
  }

  void runOnFunction() override {
    auto func = getFunction();

    OwningRewritePatternList patterns(&getContext());
    // clang-format off
    patterns.insert<
        ReduceOpConverter,
        ReduceWindowOpConverter,
        SelectAndScatterOpConverter
      >(func.getContext());
    // clang-format on

    ConversionTarget target(getContext());
    target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
                           StandardOpsDialect, scf::SCFDialect, LmhloDialect>();
    target.addIllegalOp<lmhlo::ReduceOp, lmhlo::ReduceWindowOp,
                        lmhlo::SelectAndScatterOp>();

    if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};
}  // namespace

std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass() {
  return std::make_unique<LhloLegalizeToParallelLoopsPass>();
}
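
// A minimal usage sketch from the mlir-hlo-opt tool; the pass flag name is
// assumed from the usual pass registration, which lives outside this file:
//   mlir-hlo-opt --lhlo-legalize-to-parallel-loops input.mlir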

}  // namespace lmhlo
}  // namespace mlir