/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.

#include <numeric>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace {
/// Returns an ArrayAttr that contains `nLoops` attributes. All the attributes
/// are "parallel" except the last `nReduction` elements, which are "reduction"
/// attributes.
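/// For example (illustrative), GetParallelAndReductionIterators(4, 1) returns
/// {"parallel", "parallel", "parallel", "reduction"}.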
SmallVector<StringRef, 3> GetParallelAndReductionIterators(
    unsigned nLoops, unsigned nReduction) {
  SmallVector<StringRef, 3> res(nLoops - nReduction,
                                getParallelIteratorTypeName());
  res.append(nReduction, getReductionIteratorTypeName());
  return res;
}

SmallVector<StringRef, 3> GetNParallelLoopsAttrs(unsigned nParallelLoops) {
  return GetParallelAndReductionIterators(nParallelLoops, 0);
}

template <bool isLHLO = true>
Value GetResultValue(Operation* op) {
  return isLHLO ? op->getOperand(op->getNumOperands() - 1) : op->getResult(0);
}

template <bool isLHLO = true>
ShapedType GetHloOpResultType(Operation* op) {
  return GetResultValue<isLHLO>(op).getType().template cast<ShapedType>();
}

template <bool isLHLO = true>
bool VerifyHloOpBufferOrTensorSemantics(Operation* op) {
  auto verify_type = [&](Value val) -> bool {
    return (isLHLO && val.getType().isa<MemRefType>()) ||
           (!isLHLO && val.getType().isa<RankedTensorType>());
  };
  if (!llvm::all_of(op->getOperands(), verify_type)) return false;
  return isLHLO ? op->getResults().empty()
                : llvm::all_of(op->getResults(), verify_type);
}

Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type,
                    ArrayRef<Value> dyn_sizes) {
  return b.create<linalg::InitTensorOp>(loc, dyn_sizes, type.getShape(),
                                        type.getElementType());
}

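/// Collects one DimOp value per dynamic dimension of `tensor`. For example
/// (illustrative), for a value of type tensor<4x?x?xf32> this returns DimOps
/// for dimensions 1 and 2.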
SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc,
                                          Value tensor) {
  auto tensor_type = tensor.getType().dyn_cast<RankedTensorType>();
  if (!tensor_type) return {};
  SmallVector<Value, 2> dyn_sizes;
  for (auto& en : llvm::enumerate(tensor_type.getShape())) {
    if (en.value() != ShapedType::kDynamicSize) continue;
    dyn_sizes.push_back(b.create<DimOp>(loc, tensor, en.index()));
  }
  return dyn_sizes;
}

SmallVector<int64_t, 4> Extract1DVector(DenseIntElementsAttr elements) {
  SmallVector<int64_t, 4> ret;
  for (const APInt& element : elements) {
    ret.push_back(element.getLimitedValue());
  }
  return ret;
}

/// Returns the constant value associated with the init value if the defining
/// operation is a constant.
Attribute GetInitValueAsConst(Value init) {
  DenseElementsAttr attr;
  if (!matchPattern(init, m_Constant(&attr))) return {};
  auto type = attr.getType().dyn_cast<ShapedType>();
  if (!type || type.getRank() != 0) return {};
  return attr.getValue({});
}

/// Returns a permutation AffineMap that puts all reduction dimensions to the
/// last. The relative order of the parallel loops and of the reduction loops
/// is preserved. E.g., if `rank` is 4 and `reductionDims` is {1, 3}, then
/// "(d0, d1, d2, d3) -> (d0, d2, d1, d3)" is used. The inverse permutation of
/// the AffineMap is returned.
AffineMap GetTransposeMapForReduction(MLIRContext* context, int rank,
                                      ArrayRef<int64_t> reduction_dims) {
  llvm::SmallSetVector<int, 4> s;
  for (auto dim : reduction_dims) s.insert(dim);

  SmallVector<unsigned, 4> permutation;
  for (int i = 0; i < rank; ++i)
    if (!s.count(i)) permutation.push_back(i);
  for (auto dim : reduction_dims) permutation.push_back(dim);

  auto map = AffineMap::getPermutationMap(permutation, context);
  return inversePermutation(map);
}

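// Lowers element-wise HLO/LHLO ops to a linalg.generic with identity indexing
// maps and all-"parallel" iterators. As an illustrative sketch (assumed shape
// of the lowering, not verified compiler output), an mhlo.add on
// tensor<4x?xf32> operands becomes roughly:
//
//   %init = linalg.init_tensor [4, %d1] : tensor<4x?xf32>
//   %sum = linalg.generic {indexing_maps = [#id, #id, #id],
//                          iterator_types = ["parallel", "parallel"]}
//       ins(%lhs, %rhs : tensor<4x?xf32>, tensor<4x?xf32>)
//       outs(%init : tensor<4x?xf32>) {
//   ^bb0(%a: f32, %b: f32, %out: f32):
//     %add = addf %a, %b : f32
//     linalg.yield %add : f32
//   } -> tensor<4x?xf32>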
template <typename OpTy, bool isLHLO = true>
class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = op.getLoc();
    ShapedType t0 = args[0].getType().template dyn_cast<ShapedType>();
    if (!t0) return failure();

    unsigned nloops = t0.getRank();
    auto fail = [&](ShapedType t) {
      return !t || !t.hasRank() || t.getRank() != nloops ||
             !(t.getElementType().isSignlessIntOrFloat() ||
               t.getElementType().isa<ComplexType>());
    };
    if (llvm::any_of(args,
                     [&](Value v) {
                       return fail(v.getType().dyn_cast<ShapedType>());
                     }) ||
        llvm::any_of(op.getOperation()->getResultTypes(),
                     [&](Type t) { return fail(t.dyn_cast<ShapedType>()); }))
      return emitError(loc,
                       "lhlo to linalg conversion expects ranked args of "
                       "signless int, float or complex element type with ")
             << nloops << " parallel iterators: " << *(op.getOperation());

    // Construct the indexing maps needed for linalg.generic ops.
    SmallVector<Type, 4> body_arg_types, body_result_types, op_result_types;
    // This doesn't account for implicit broadcast, but the working assumption
    // in HLO/LHLO is that all broadcasts are made explicit.

    if (isLHLO && !nloops) return failure();

    int num_inputs = (isLHLO ? args.size() - 1 : args.size());

    ValueRange inputs(args.take_front(num_inputs));
    for (Value in : inputs)
      body_arg_types.emplace_back(getElementTypeOrSelf(in.getType()));

    SmallVector<Value, 4> output_buffers;
    if (isLHLO) {
      output_buffers.append(args.begin() + num_inputs, args.end());
    } else {
      Value result = op.getOperation()->getResult(0);
      ShapedType result_type = result.getType().template cast<ShapedType>();
      auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, args[0]);
      output_buffers.push_back(
          GetInitTensor(rewriter, loc, result_type, dyn_sizes));
      op_result_types.push_back(result.getType());
    }
    body_result_types = llvm::to_vector<4>(llvm::map_range(
        output_buffers, [](Value v) { return getElementTypeOrSelf(v); }));

    AffineMap common_indexing_map =
        nloops ? rewriter.getMultiDimIdentityMap(nloops)
               : AffineMap::get(nloops, 0, rewriter.getContext());
    SmallVector<AffineMap, 2> indexing_maps(args.size() + (isLHLO ? 0 : 1),
                                            common_indexing_map);

    bool failed = false;
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, op_result_types, inputs, output_buffers, indexing_maps,
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
          // TODO(ravishankarm) : For now use the method in lmhlo namespace.
          // That method needs to be moved out of there.
          Value op_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
              op, body_result_types,
              llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
          if (op_result == nullptr) {
            failed = true;
          } else {
            nested_builder.create<linalg::YieldOp>(loc, op_result);
          }
        });
    if (failed) return failure();
    rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
    return success();
  }
};

template <typename LhloOp>
class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
 public:
  using OpConversionPattern<LhloOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      LhloOp lhlo_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = lhlo_op.getLoc();
    auto arg_type =
        lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>();
    if (!arg_type || !arg_type.getElementType().isSignlessIntOrFloat() ||
        (arg_type.getRank() != 0)) {
      return failure();
    }

    // Create two loads from the input.
    auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
    auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
    // TODO(ravishankarm) : Move this method out of lmhlo namespace.
    Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
        lhlo_op, arg_type.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
        &rewriter);
    rewriter.create<StoreOp>(loc, op_result, lhlo_op.out());
    rewriter.eraseOp(lhlo_op);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// lmhlo.convolution conversion pattern.
//===----------------------------------------------------------------------===//

/// Converts lmhlo.convolution operation to a linalg.conv op.
struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
 public:
  using OpConversionPattern<lmhlo::ConvOp>::OpConversionPattern;

  //  This code has been adapted from IREE's
  //  (https://github.com/google/iree/) mhlo -> linalg conversion.
  LogicalResult matchAndRewrite(
      lmhlo::ConvOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    // Check validity of dimension information.
    if (const mhlo::ConvDimensionNumbers& dimension_numbers =
            op.dimension_numbers()) {
      const int input_spatial_rank =
          llvm::size(dimension_numbers.input_spatial_dimensions());
      // The dimensions for input should follow the order of
      // batch_count, spatial_dims..., input_feature_count.
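      // E.g. (illustrative), for a 2-D convolution this is the NHWC layout:
      // batch = 0, spatial dims = {1, 2}, feature = 3.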
      if (dimension_numbers.input_batch_dimension().getInt() != 0 ||
          dimension_numbers.input_feature_dimension().getInt() !=
              (input_spatial_rank + 1))
        return failure();

      const int kernel_spatial_rank =
          llvm::size(dimension_numbers.kernel_spatial_dimensions());
      // The dimensions for filter should follow the order of
      // spatial_dims..., input_feature_count, output_feature_count.
      if (dimension_numbers.kernel_input_feature_dimension().getInt() !=
              kernel_spatial_rank ||
          dimension_numbers.kernel_output_feature_dimension().getInt() !=
              (kernel_spatial_rank + 1))
        return failure();

      const int output_spatial_rank =
          llvm::size(dimension_numbers.output_spatial_dimensions());
      // The dimensions for output should follow the order of
      // batch_count, spatial_dims..., output_feature_count.
      if (dimension_numbers.output_batch_dimension().getInt() != 0 ||
          dimension_numbers.output_feature_dimension().getInt() !=
              (output_spatial_rank + 1))
        return failure();

      if (input_spatial_rank != output_spatial_rank ||
          input_spatial_rank != kernel_spatial_rank)
        return failure();

      auto input_spatial_dim =
          dimension_numbers.input_spatial_dimensions().begin();
      auto kernel_spatial_dim =
          dimension_numbers.kernel_spatial_dimensions().begin();
      auto output_spatial_dim =
          dimension_numbers.output_spatial_dimensions().begin();
      // Check if spatial dims are ordered correctly.
      for (int i = 0; i < input_spatial_rank; ++i) {
        const int dim = i + 1;
        if ((*input_spatial_dim++).getZExtValue() != dim ||
            (*output_spatial_dim++).getZExtValue() != dim ||
            (*kernel_spatial_dim++).getZExtValue() != i)
          return failure();
      }
    }

    // TODO: LHS dilation for deconvolution not supported yet.
    // TODO(jurahul): Window reversal is not supported yet.
    if (op.lhs_dilation() || op.hasWindowReversal()) {
      return failure();
    }

    llvm::SmallVector<Attribute, 4> strides;
    if (auto window_strides = op.window_strides()) {
      auto range = window_strides->getAttributeValues();
      strides.assign(range.begin(), range.end());
    }
    auto strides_arg = ArrayAttr::get(op.getContext(), strides);

    llvm::SmallVector<Attribute, 2> dilation;
    if (auto rhs_dilation = op.rhs_dilation()) {
      auto range = rhs_dilation->getAttributeValues();
      dilation.assign(range.begin(), range.end());
    } else {
      // Default dilation of 1.
      dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1));
    }
    auto dilation_arg = ArrayAttr::get(op.getContext(), dilation);

    // Set padding only if it is non-zero.
    DenseIntElementsAttr padding = op.paddingAttr();
    if (!padding ||
        !llvm::any_of(padding.getValues<APInt>(),
                      [](APInt int_val) { return !int_val.isNullValue(); })) {
      padding = nullptr;
    }

    // The order of input and filter is switched with linalg.conv.
    rewriter.replaceOpWithNewOp<linalg::ConvOp>(
        op, args[1], args[0], args[2], strides_arg, dilation_arg, padding);
    return success();
  }
};

/// Base class for lowering HLO operations that have one operand and one
/// result, and are semantically equivalent to a copy of the input to the
/// output (like transpose, some reshape, etc.). The derived classes need to
/// provide a method `getIndexingMaps` that returns AffineMaps for the index
/// maps of the input and the output.
template <typename Derived, typename OpTy, bool isLHLO = true>
class DataMovementOpConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    if (!VerifyHloOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
    auto result_type = GetHloOpResultType<isLHLO>(op);

    SmallVector<AffineMap, 2> indexing_maps =
        Derived::getIndexingMaps(op, &rewriter);
    if (indexing_maps.empty()) return failure();

    auto nloops = result_type.getRank();
    auto loc = op.getLoc();
    // TODO(pifon): technically, the op itself could have size operands (e.g.
    // broadcast into a dynamic dimension). Handle this case.
    auto dyn_sizes = isLHLO ? SmallVector<Value, 2>()
                            : ExtractDynamicSizes(rewriter, loc, args[0]);
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc,
        /*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : result_type,
        /*inputs=*/args.front(),
        /*outputBuffers=*/
        isLHLO
            ? ValueRange{args.back()}
            : ValueRange{GetInitTensor(rewriter, loc, result_type, dyn_sizes)},
        indexing_maps, GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
          nested_builder.create<linalg::YieldOp>(loc, *args.begin());
        });
    rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
    return success();
  }
};

/// Pattern to convert BroadcastOp to Linalg ops.
template <typename OpTy, bool isLHLO = true>
class BroadcastConverter
    : public DataMovementOpConverter<BroadcastConverter<OpTy, isLHLO>, OpTy,
                                     isLHLO> {
 public:
  using DataMovementOpConverter<BroadcastConverter, OpTy,
                                isLHLO>::DataMovementOpConverter;

  static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcast_op,
                                                   Builder* b) {
    ShapedType input_type =
        broadcast_op.operand().getType().template cast<ShapedType>();
    unsigned input_rank = input_type.getRank();
    unsigned nloops = GetHloOpResultType<isLHLO>(broadcast_op).getRank();

    // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute
    // to the input's dimensions.
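    // E.g. (illustrative), broadcasting a rank-1 input to a rank-3 result
    // (two prepended dims) gives the input map (d0, d1, d2) -> (d2).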
    unsigned num_prepended_dims = llvm::size(broadcast_op.broadcast_sizes());
    SmallVector<AffineExpr, 4> input_dim_exprs;
    input_dim_exprs.reserve(input_rank);
    for (int i = 0; i < input_rank; ++i) {
      input_dim_exprs.push_back(b->getAffineDimExpr(num_prepended_dims + i));
    }

    AffineMap input_map;
    MLIRContext* context = b->getContext();
    if (input_dim_exprs.empty()) {
      // The input is a scalar, i.e. this is a scalar broadcast op.
      input_map = AffineMap::get(nloops, /*symbolCount=*/0, context);
    } else {
      input_map =
          AffineMap::get(nloops, /*symbolCount=*/0, input_dim_exprs, context);
    }
    return {input_map, b->getMultiDimIdentityMap(nloops)};
  }
};

class HloBroadcastInDimConverter
    : public DataMovementOpConverter<HloBroadcastInDimConverter,
                                     mhlo::BroadcastInDimOp, false> {
 public:
  using DataMovementOpConverter<HloBroadcastInDimConverter,
                                mhlo::BroadcastInDimOp,
                                false>::DataMovementOpConverter;

  static SmallVector<AffineMap, 2> getIndexingMaps(
      mhlo::BroadcastInDimOp broadcast_op, Builder* b) {
    auto result_type = GetHloOpResultType<false>(broadcast_op);
    auto operand_type =
        broadcast_op.operand().getType().template cast<ShapedType>();
    unsigned nloops = result_type.getRank();

    // The input is a scalar, i.e. this is a scalar broadcast op.
    if (operand_type.getRank() == 0) {
      return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
              b->getMultiDimIdentityMap(nloops)};
    }

    auto operand_shape = operand_type.getShape();
    SmallVector<AffineExpr, 4> dim_exprs;
    dim_exprs.reserve(nloops);

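    // Size-1 operand dims that are expanded map to the constant expr 0. E.g.
    // (illustrative), broadcasting tensor<1x4xf32> into tensor<5x4xf32> with
    // broadcast_dimensions = [0, 1] gives the input map (d0, d1) -> (0, d1).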
    if (broadcast_op.broadcast_dimensions()) {
      for (const auto& broadcastDim :
           enumerate(broadcast_op.broadcast_dimensions().getIntValues())) {
        int size = broadcastDim.value().getSExtValue();
        bool expansion_needed = operand_shape[broadcastDim.index()] == 1 &&
                                result_type.getShape()[size] != 1;
        dim_exprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
                                             : b->getAffineDimExpr(size));
      }
    }
    return {
        AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
        b->getMultiDimIdentityMap(nloops)};
  }
};

class HloDynamicBroadcastInDimConverter
    : public OpConversionPattern<mhlo::DynamicBroadcastInDimOp> {
 public:
  using OpConversionPattern<mhlo::DynamicBroadcastInDimOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    // Convert only if the producer is an HLO constant. Ideally the pattern
    // (`mhlo.constant` -> `mhlo.dynamic_broadcast_in_dim`) should be converted
    // to a Tensor-dialect op similar to TF ConstantLikeOp.
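    // Illustrative sketch: the broadcast then becomes a linalg.generic over a
    // linalg.init_tensor whose sizes are tensor.extract'ed one by one from
    // the output_dimensions operand.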
    if (!op.operand().getDefiningOp<mhlo::ConstOp>()) return failure();

    mhlo::DynamicBroadcastInDimOp::Adaptor adaptor(op);
    Value operand = adaptor.operand();
    auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
    if (!operand_type || operand_type.getRank() != 0) return failure();

    Value shape = adaptor.output_dimensions();
    auto shape_type = shape.getType().cast<RankedTensorType>();
    int64_t result_rank = shape_type.getDimSize(0);

    SmallVector<Value, 2> dyn_dims;
    Location loc = op.getLoc();
    for (int i = 0; i < result_rank; ++i) {
      Value index = rewriter.create<ConstantIndexOp>(loc, i);
      dyn_dims.push_back(rewriter.create<tensor::ExtractOp>(loc, shape, index));
    }
    auto result_type = op.getType().dyn_cast<RankedTensorType>();
    if (!result_type) return failure();

    int64_t nloops = result_type.getRank();
    Value init = rewriter.create<linalg::InitTensorOp>(
        loc, dyn_dims, result_type.getShape(), result_type.getElementType());
    Operation* generic = rewriter.create<linalg::GenericOp>(
        loc, TypeRange{init.getType()}, ValueRange{operand},
        /*outputBuffers=*/ValueRange{init},
        llvm::makeArrayRef(
            {AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0, {},
                            rewriter.getContext()),
             rewriter.getMultiDimIdentityMap(nloops)}),
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
          nested_builder.create<linalg::YieldOp>(loc, *args.begin());
        });
    rewriter.replaceOp(op, generic->getResults());
    return success();
  }
};

class LhloBroadcastInDimConverter
    : public OpConversionPattern<lmhlo::BroadcastInDimOp> {
 public:
  using OpConversionPattern<lmhlo::BroadcastInDimOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::BroadcastInDimOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
    auto result_type = operand_adaptor.output().getType().cast<MemRefType>();
    auto result_shape = result_type.getShape();

    auto operand_and_dims = InsertReshapeIfNecessary(op, args, rewriter);

    Value operand = std::get<0>(operand_and_dims);
    auto broadcast_dims = std::get<1>(operand_and_dims);

    auto loc = op.getLoc();
    auto nloops = result_type.getRank();
    auto operand_type = operand.getType().cast<MemRefType>();

    // For a degenerate case, i.e. broadcasting with expansion of
    // memref<1xELEMENT_TYPE>, the operand is not passed to `linalg.generic`.
    // Instead the value is loaded and used directly in `linalg.yield`.
    if (operand_type.getRank() == 1 &&
        operand_type.getDimSize(0) <
            result_type.getDimSize(broadcast_dims.front())) {
      Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
      Value val =
          rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
      rewriter.create<linalg::GenericOp>(
          loc, /*inputs=*/ValueRange{},
          /*outputBuffers=*/ValueRange{operand_adaptor.output()},
          llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
          GetNParallelLoopsAttrs(nloops),
          [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
            nested_builder.create<linalg::YieldOp>(loc, val);
          });

    } else {
      auto indexing_maps = getIndexingMaps(op, broadcast_dims, result_shape,
                                           operand_type, &rewriter);
      rewriter.create<linalg::GenericOp>(
          loc, /*inputs=*/ValueRange{operand},
          /*outputBuffers=*/ValueRange{operand_adaptor.output()}, indexing_maps,
          GetNParallelLoopsAttrs(nloops),
          [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
            nested_builder.create<linalg::YieldOp>(loc, *args.begin());
          });
    }
    rewriter.replaceOp(op, llvm::None);
    return success();
  }

  // Inserts 'linalg.reshape' if there is a size-1 dim expansion.
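  // E.g. (illustrative), broadcasting memref<1x4xf32> into memref<5x4xf32>
  // with broadcast_dimensions = [0, 1] collapses the operand to memref<4xf32>
  // and returns the remaining broadcast dims {1}.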
  std::pair<Value, SmallVector<int64_t, 2>> InsertReshapeIfNecessary(
      lmhlo::BroadcastInDimOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const {
    lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
    Value operand = operand_adaptor.operand();
    auto operand_type = operand_adaptor.operand().getType().cast<MemRefType>();
    auto operand_shape = operand_type.getShape();

    Value result = operand_adaptor.output();
    auto result_type = result.getType().cast<MemRefType>();
    auto result_shape = result_type.getShape();

    SmallVector<int64_t, 2> operand_strides;
    int64_t operand_offset;
    if (failed(getStridesAndOffset(operand_type, operand_strides,
                                   operand_offset))) {
      op.emitOpError() << "Failed to get offset and strides.";
    }

    SmallVector<int64_t, 2> new_shape, new_strides, broadcast_dims;
    SmallVector<linalg::ReassociationIndices, 4> collapsed_dims_list;
    linalg::ReassociationIndices collapsed_dims;
    for (const auto& item :
         enumerate(op.broadcast_dimensions().getIntValues())) {
      size_t index = item.index();
      int dim = item.value().getSExtValue();

      collapsed_dims.push_back(index);

      bool expansion_needed =
          operand_shape[index] == 1 && result_shape[dim] != 1;
      if (expansion_needed) {
        continue;
      }
      new_shape.push_back(operand_shape[index]);
      new_strides.push_back(operand_strides[index]);
      broadcast_dims.push_back(dim);

      collapsed_dims_list.push_back(collapsed_dims);
      collapsed_dims.clear();
    }
    // If `collapsed_dims_list` is empty, then the memref has shape [1, ..., 1]
    // and all dimensions need expansion. Such a memref will be reshaped to a
    // 1D memref with a single element. The new shape and strides need to be
    // updated accordingly.
    if (collapsed_dims_list.empty()) {
      collapsed_dims_list.push_back({});
      new_shape.push_back(1);
      new_strides.push_back(1);
      broadcast_dims.push_back(0);
    }
    for (const auto& dims : collapsed_dims) {
      collapsed_dims_list.back().push_back(dims);
    }

    // `linalg.reshape` is inserted only if necessary, i.e. when the rank can be
    // reduced.
    if (new_shape.size() < operand_shape.size()) {
      auto new_memref_type = MemRefType::get(
          new_shape, operand_type.getElementType(),
          makeStridedLinearLayoutMap(new_strides, operand_offset,
                                     rewriter.getContext()));
      operand = rewriter.create<linalg::ReshapeOp>(op.getLoc(), new_memref_type,
                                                   operand_adaptor.operand(),
                                                   collapsed_dims_list);
    }
    return std::make_pair(operand, broadcast_dims);
  }

  SmallVector<AffineMap, 2> getIndexingMaps(lmhlo::BroadcastInDimOp op,
                                            ArrayRef<int64_t> broadcast_dims,
                                            ArrayRef<int64_t> result_shape,
                                            MemRefType operand_type,
                                            Builder* b) const {
    unsigned nloops = result_shape.size();

    // The input is a scalar, i.e. this is a scalar broadcast op.
    if (operand_type.getRank() == 0) {
      return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
              b->getMultiDimIdentityMap(nloops)};
    }

    auto operand_shape = operand_type.getShape();
    SmallVector<AffineExpr, 4> dim_exprs;
    dim_exprs.reserve(nloops);

    for (const auto& broadcast_dim : llvm::enumerate(broadcast_dims)) {
      int size = broadcast_dim.value();
      bool expansion_needed =
          operand_shape[broadcast_dim.index()] == 1 && result_shape[size] != 1;
      if (expansion_needed) {
        op.emitOpError(
            "BroadcastInDimOp lowering to Linalg does not support size-1 "
            "dimensions expansion.");
      }
      dim_exprs.push_back(b->getAffineDimExpr(size));
    }
    return {
        AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
        b->getMultiDimIdentityMap(nloops)};
  }
};

template <typename OpTy, bool isLHLO = true>
class TransposeConverter
    : public DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
                                     isLHLO> {
 public:
  using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
                                isLHLO>::DataMovementOpConverter;
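  // E.g. (illustrative), permutation = [1, 0] yields the input map
  // (d0, d1) -> (d1, d0).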
  static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
    auto result_type =
        GetHloOpResultType<isLHLO>(op).template cast<ShapedType>();
    auto nloops = result_type.getRank();
    SmallVector<AffineExpr, 2> input_exprs;
    input_exprs.resize(result_type.getRank());
    for (auto permutation : llvm::enumerate(op.permutation())) {
      input_exprs[permutation.value().getZExtValue()] =
          b->getAffineDimExpr(permutation.index());
    }
    return {
        AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
        b->getMultiDimIdentityMap(nloops)};
  }
};

// Converts reshape ops that can be proven to be either a collapse of
// dimensions or an expansion of dimensions of the operand.
template <typename OpTy, bool isLHLO = true>
class ReshapeOpConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy reshape_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    if (!VerifyHloOpBufferOrTensorSemantics<isLHLO>(reshape_op))
      return failure();
    typename OpTy::Adaptor operands(args);
    ShapedType operand_type =
        operands.operand().getType().template cast<ShapedType>();
    ShapedType result_type = GetHloOpResultType<isLHLO>(reshape_op);

    if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
      return failure();

    // Compute the reassociation maps for the linalg operation.
    ArrayRef<int64_t> src_shape =
        (operand_type.getRank() > result_type.getRank()
             ? operand_type.getShape()
             : result_type.getShape());
    ArrayRef<int64_t> dst_shape =
        (operand_type.getRank() > result_type.getRank()
             ? result_type.getShape()
             : operand_type.getShape());
    unsigned curr_src_dim = 0, curr_dst_dim = 0;
    SmallVector<linalg::ReassociationExprs, 4> reassociation_map(
        dst_shape.size());

    // First scan all dimensions in the source shapes to see whether we have a
    // perfect case where consecutive dimensions in the source are collapsed.
    // In that case we can generate a single linalg.reshape.
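    // E.g. (illustrative), collapsing tensor<2x3x4xf32> into tensor<6x4xf32>
    // yields the reassociation [[d0, d1], [d2]].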
    bool is_collapsing_source = true;
    while (curr_src_dim < src_shape.size() && curr_dst_dim < dst_shape.size()) {
      int64_t dst_size = dst_shape[curr_dst_dim];
      int64_t src_size = src_shape[curr_src_dim];
      while (src_size < dst_size && curr_src_dim < src_shape.size()) {
        reassociation_map[curr_dst_dim].push_back(
            rewriter.getAffineDimExpr(curr_src_dim++));
        src_size *= src_shape[curr_src_dim];
      }
      if (src_size == dst_size) {
        reassociation_map[curr_dst_dim].push_back(
            rewriter.getAffineDimExpr(curr_src_dim++));
        // If the next dim in dst_shape is not 1, treat subsequent size-1 dims
        // in src_shape as collapsed into the current dst dim.
        if (curr_dst_dim == dst_shape.size() - 1 ||
            dst_shape[curr_dst_dim + 1] != 1) {
          while (curr_src_dim < src_shape.size() &&
                 src_shape[curr_src_dim] == 1) {
            reassociation_map[curr_dst_dim].push_back(
                rewriter.getAffineDimExpr(curr_src_dim++));
          }
        }
      } else {
        is_collapsing_source = false;
        break;
      }
      curr_dst_dim++;
    }
    if (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size())
      is_collapsing_source = false;

    // Otherwise, we need to first reduce all source dimensions into one and
    // then expand to the destination dimensions.
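    // E.g. (illustrative), tensor<2x3xf32> -> tensor<3x2xf32> is lowered as a
    // collapse to tensor<6xf32> followed by an expansion to tensor<3x2xf32>.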
    if (!is_collapsing_source) {
      auto get_identity_exprs = [&rewriter](int n) {
        SmallVector<AffineExpr, 4> exprs;
        for (int i = 0; i < n; ++i)
          exprs.push_back(rewriter.getAffineDimExpr(i));
        return exprs;
      };
      Location loc = reshape_op.getLoc();
      int64_t total_elems = std::accumulate(src_shape.begin(), src_shape.end(),
                                            1, std::multiplies<int64_t>());
      auto elem_type = operand_type.getElementType();
      SmallVector<linalg::ReassociationExprs, 4> collapsing_map = {
          // Use operand_type here because we need to collapse all operand
          // dimensions.
          get_identity_exprs(operand_type.getShape().size())};
      SmallVector<linalg::ReassociationExprs, 4> expanding_map = {
          // Use result_type here because we need to expand to all result
          // dimensions.
          get_identity_exprs(result_type.getShape().size())};

      if (isLHLO) {
        auto collapsed_type = MemRefType::get({total_elems}, elem_type);
        Value collapsed_op = rewriter.create<linalg::ReshapeOp>(
            loc, collapsed_type, args[0], collapsing_map);
        Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
            loc, result_type, collapsed_op, expanding_map);
        rewriter.replaceOpWithNewOp<linalg::CopyOp>(reshape_op, reshape_buffer,
                                                    args[1]);
      } else {
        auto collapsed_type = RankedTensorType::get({total_elems}, elem_type);
        Value collapsed_op = rewriter.create<linalg::TensorReshapeOp>(
            loc, collapsed_type, args[0], collapsing_map);
        rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
            reshape_op, result_type, collapsed_op, expanding_map);
      }
      return success();
    }

    if (isLHLO) {
      Value reshape_buffer = rewriter.create<linalg::ReshapeOp>(
          reshape_op.getLoc(), result_type, args[0], reassociation_map);
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(reshape_op, reshape_buffer,
                                                  args[1]);
    } else {
      rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
          reshape_op, result_type, args[0], reassociation_map);
    }
    return success();
  }
};

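// Lowers iota ops to a linalg.indexed_generic whose body casts the iteration
// index of iota_dimension to the result element type. E.g. (illustrative), a
// 1-D iota of type tensor<4xi64> produces the values 0, 1, 2, 3.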
template <typename OpTy, bool isLHLO = true>
class IotaConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy iota_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    ShapedType result_shaped_type = GetHloOpResultType<isLHLO>(iota_op);
    if (!result_shaped_type) return failure();

    auto result_element_type = result_shaped_type.getElementType();
    if (!result_element_type.isSignlessIntOrFloat()) return failure();

    // Construct the indexing maps needed for linalg.generic ops.
    unsigned nloops = result_shaped_type.getRank();

    Location loc = iota_op.getLoc();
    auto dyn_sizes = isLHLO
                         ? SmallVector<Value, 2>()
                         : ExtractDynamicSizes(rewriter, loc,
                                               GetResultValue<isLHLO>(iota_op));
    auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
        loc,
        /*resultTensorTypes=*/
        isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type},
        /*inputs=*/ValueRange{},
        /*outputBuffers=*/
        isLHLO ? ValueRange{args}
               : ValueRange{GetInitTensor(rewriter, loc, result_shaped_type,
                                          dyn_sizes)},
        llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange ivs,
            ValueRange args) {
          Value cast_op = nested_builder.create<IndexCastOp>(
              nested_loc, ivs[iota_op.iota_dimension()],
              nested_builder.getIntegerType(
                  result_element_type.getIntOrFloatBitWidth()));
          if (result_element_type.template isa<FloatType>()) {
            cast_op = nested_builder.create<SIToFPOp>(nested_loc, cast_op,
                                                      result_element_type);
          }
          nested_builder.create<linalg::YieldOp>(nested_loc, cast_op);
        });
    if (isLHLO)
      rewriter.replaceOp(iota_op, llvm::None);
    else
      rewriter.replaceOp(iota_op, linalg_op.result_tensors());
    return success();
  }
};

template <typename OpTy>
class ConstConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy const_op, ArrayRef<Value> /*args*/,
      ConversionPatternRewriter& rewriter) const final {
    Location loc = const_op.getLoc();
    auto value_attr = const_op.value().template cast<DenseElementsAttr>();
    if (value_attr.getType().getRank() != 0) return failure();
    ReplaceConstOp(loc, const_op, value_attr, rewriter);
    return success();
  }

 private:
  void ReplaceConstOp(Location loc, mhlo::ConstOp op,
                      DenseElementsAttr value_attr,
                      ConversionPatternRewriter& rewriter) const {
    Value std_tensor_const = rewriter.create<mlir::ConstantOp>(loc, value_attr);
    rewriter.replaceOp(op, {std_tensor_const});
  }
  void ReplaceConstOp(Location loc, lmhlo::ConstOp op,
                      DenseElementsAttr value_attr,
                      ConversionPatternRewriter& rewriter) const {
    Value std_scalar_const =
        rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
    rewriter.create<mlir::AffineStoreOp>(loc, std_scalar_const, op.getOperand(),
                                         llvm::None);
    rewriter.eraseOp(op);
  }
};

class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
 public:
  using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::ReduceOp reduce_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = reduce_op.getLoc();
    lmhlo::ReduceOp::Adaptor adaptor(args);
    auto operand_shape =
        adaptor.operands()[0].getType().template dyn_cast<ShapedType>();
    if (!operand_shape || !operand_shape.hasRank()) {
      emitError(loc, "lhlo to linalg conversion expects known-rank args");
      return failure();
    }

    // First fill the output buffer with the init value.
    Value init_value = rewriter.create<LoadOp>(loc, adaptor.init_values()[0]);
    rewriter.create<linalg::FillOp>(loc, adaptor.out()[0], init_value);

    DenseIntElementsAttr dimensions_attr = reduce_op.dimensions();
    SmallVector<int, 4> reduction_dims;
    for (const auto& dim : dimensions_attr.getIntValues()) {
      reduction_dims.push_back(dim.getSExtValue());
    }

    SmallVector<AffineExpr, 2> src_exprs;
    SmallVector<AffineExpr, 2> dst_exprs;
    SmallVector<StringRef, 4> types;
    for (int i = 0, rank = operand_shape.getRank(); i != rank; ++i) {
      bool is_reduced = llvm::is_contained(reduction_dims, i);
      types.push_back(is_reduced ? getReductionIteratorTypeName()
                                 : getParallelIteratorTypeName());

      src_exprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (!is_reduced) {
        dst_exprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      }
    }

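    // E.g. (illustrative), reducing a rank-2 operand over dim 1 gives the src
    // map (d0, d1) -> (d0, d1), the dst map (d0, d1) -> (d0), and iterator
    // types ["parallel", "reduction"].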
    auto maps = AffineMap::inferFromExprList({src_exprs, dst_exprs});

    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/ArrayRef<Type>{},
        /*inputs=*/adaptor.operands(), /*outputBuffers=*/adaptor.out(), maps,
        types);
    rewriter.inlineRegionBefore(reduce_op.body(), linalg_op.region(),
                                linalg_op.region().end());
    {
      OpBuilder::InsertionGuard region_guard(rewriter);
      Block* block = linalg_op.getBody();
      rewriter.setInsertionPoint(&block->front());

      // The incoming region is operating on buffers, while linalg.generic
      // expects scalar SSA values. Add some allocs around the original op to
      // make it compatible.
      auto arg_type = block->getArgument(0).getType().cast<MemRefType>();
      Value alloc_a = rewriter.create<AllocaOp>(loc, arg_type);
      Value alloc_b = rewriter.create<AllocaOp>(loc, arg_type);
      Value alloc_res = rewriter.create<AllocaOp>(loc, arg_type);

      // Now turn the existing signature
      //   (memref<X>, memref<X>, memref<X>) -> ()
      // into
      //   (X, X) -> X
      TypeConverter::SignatureConversion signature_converter(3);
      signature_converter.remapInput(0, alloc_a);
      signature_converter.remapInput(1, alloc_b);
      signature_converter.remapInput(2, alloc_res);
      signature_converter.addInputs(
          {arg_type.getElementType(), arg_type.getElementType()});
      Block* entry_block = rewriter.applySignatureConversion(
          &linalg_op.region(), signature_converter);

      // Store the arguments into the newly allocated buffers.
      rewriter.setInsertionPointAfter(alloc_res.getDefiningOp());
      rewriter.create<StoreOp>(loc, entry_block->getArgument(0), alloc_a);
      rewriter.create<StoreOp>(loc, entry_block->getArgument(1), alloc_b);
      rewriter.replaceOp(entry_block->getTerminator(), {});

      // Load & yield the result.
      rewriter.setInsertionPointToEnd(entry_block);
      auto load_res = rewriter.create<LoadOp>(loc, alloc_res);
      rewriter.create<linalg::YieldOp>(loc, ValueRange{load_res});
    }

    rewriter.replaceOp(reduce_op, linalg_op.getOperation()->getResults());
    return success();
  }
};

// TODO(b/156787842): Support the lowering for dynamic shapes.
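// Lowers reverse ops by flipping the indexing expression of each reversed
// dimension, e.g. (illustrative) reversing dim 0 of a tensor<4x?xf32> uses
// the input map (d0, d1) -> (3 - d0, d1).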
template <typename OpTy, bool isLHLO = true>
class ReverseConverter
    : public DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
                                     isLHLO> {
 public:
  using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
                                isLHLO>::DataMovementOpConverter;
  static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
    auto result_type =
        GetHloOpResultType<isLHLO>(op).template cast<ShapedType>();
    auto nloops = result_type.getRank();
    SmallVector<AffineExpr, 2> input_exprs;
    input_exprs.reserve(nloops);
    for (int i = 0; i < nloops; ++i)
      input_exprs.push_back(b->getAffineDimExpr(i));
    for (auto dim : op.dimensions()) {
      int i = dim.getZExtValue();
      if (result_type.isDynamicDim(i)) return {};
      int n = result_type.getShape()[i];
      input_exprs[i] = b->getAffineConstantExpr(n - 1) - input_exprs[i];
    }
    return {
        AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
        b->getMultiDimIdentityMap(nloops)};
  }
};

class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
 public:
  using OpConversionPattern<lmhlo::SliceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::SliceOp slice_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = slice_op.getLoc();
    auto arg_type =
        slice_op.getOperand(0).getType().template dyn_cast<ShapedType>();
    if (!arg_type || !arg_type.hasRank()) {
      emitError(loc, "lhlo to linalg conversion expects known-rank args");
      return failure();
    }

    SmallVector<OpFoldResult, 3> offsets, sizes, strides;
    for (int i = 0, e = arg_type.getRank(); i < e; ++i) {
      offsets.push_back(rewriter.getI64IntegerAttr(
          slice_op.start_indices().getValue<int64_t>(i)));
      sizes.push_back(rewriter.getI64IntegerAttr(
          slice_op.limit_indices().getValue<int64_t>(i) -
          slice_op.start_indices().getValue<int64_t>(i)));
      strides.push_back(
          rewriter.getI64IntegerAttr(slice_op.strides().getValue<int64_t>(i)));
    }
    auto linalg_slice = rewriter.create<SubViewOp>(loc, slice_op.getOperand(0),
                                                   offsets, sizes, strides);
    rewriter.create<linalg::CopyOp>(loc, linalg_slice, slice_op.getOperand(1));
    rewriter.eraseOp(slice_op);
    return success();
  }
};

enum class DotOperationType {
  kVectorDot = 0,
  kMatrixVector = 1,
  kMatrixMatrix = 2,
  kUnsupported = 3
};

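// Classifies an mhlo.dot by operand ranks; e.g. (illustrative), a
// tensor<4x8xf32> x tensor<8xf32> dot is kMatrixVector.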
DotOperationType GetDotOperationType(mhlo::DotOp dot_op) {
  ArrayRef<int64_t> lhs_shape =
      dot_op.lhs().getType().cast<ShapedType>().getShape();
  ArrayRef<int64_t> rhs_shape =
      dot_op.rhs().getType().cast<ShapedType>().getShape();
  auto shape_matches = [](int64_t a, int64_t b) {
    return a == ShapedType::kDynamicSize || b == ShapedType::kDynamicSize ||
           a == b;
  };
  if (lhs_shape.size() == 1 && rhs_shape.size() == 1 &&
      shape_matches(lhs_shape[0], rhs_shape[0])) {
    return DotOperationType::kVectorDot;
  }
  if (lhs_shape.size() == 2 && rhs_shape.size() == 1 &&
      shape_matches(lhs_shape[1], rhs_shape[0])) {
    return DotOperationType::kMatrixVector;
  }
  if (lhs_shape.size() == 2 && rhs_shape.size() == 2 &&
      shape_matches(lhs_shape[1], rhs_shape[0])) {
    return DotOperationType::kMatrixMatrix;
  }
  return DotOperationType::kUnsupported;
}

SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
                                                 Value lhs, Value rhs,
                                                 DotOperationType type) {
  SmallVector<Value, 2> dyn_shape;
  switch (type) {
    case DotOperationType::kMatrixMatrix: {
      if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
        dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
      if (rhs.getType().cast<ShapedType>().isDynamicDim(1))
        dyn_shape.push_back(b.create<DimOp>(loc, rhs, 1));
      break;
    }
    case DotOperationType::kMatrixVector: {
      if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
        dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
      break;
    }
    case DotOperationType::kVectorDot:
    case DotOperationType::kUnsupported:
    default: {
      break;
    }
  }
  return dyn_shape;
}

class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
 public:
  using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::DotOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
      return failure();
    }
    Location loc = op.getLoc();
    mhlo::DotOp::Adaptor adaptor(args);
    Type result_type = op.getResult().getType();
    auto shaped_type = result_type.cast<ShapedType>();
    DotOperationType op_type = GetDotOperationType(op);
    auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
    Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
    SmallVector<Value, 2> dyn_shape = GetDotOpInitTensorDynSizes(
        rewriter, loc, adaptor.lhs(), adaptor.rhs(), op_type);
    auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape);
    Value zero_tensor =
        rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
    linalg::LinalgOp linalg_op;
    switch (op_type) {
      case DotOperationType::kMatrixMatrix: {
        linalg_op = rewriter.create<linalg::MatmulOp>(
            loc, TypeRange{result_type},
            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor});
        break;
      }
      case DotOperationType::kMatrixVector: {
        linalg_op = rewriter.create<linalg::MatvecOp>(
            loc, TypeRange{result_type},
            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor});
        break;
      }
      case DotOperationType::kVectorDot: {
        linalg_op = rewriter.create<linalg::DotOp>(
            loc, TypeRange{result_type},
            ValueRange{adaptor.lhs(), adaptor.rhs()}, ValueRange{zero_tensor});
        break;
      }
      case DotOperationType::kUnsupported:
      default: {
        return op.emitError("unsupported dot operation type");
      }
    }
    rewriter.replaceOp(op, linalg_op->getResults());
    return success();
  }
};

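// For a batch-matmul result of shape (B, M, N), the dynamic sizes B and M
// come from the lhs (illustratively B x M x K) and N from the rhs (B x K x N).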
SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
    OpBuilder& b, Location loc, Value lhs, Value rhs, ShapedType result_type) {
  SmallVector<Value, 8> dyn_shape;
  if (result_type.isDynamicDim(0))
    dyn_shape.push_back(b.create<DimOp>(loc, lhs, 0));
  if (result_type.isDynamicDim(1))
    dyn_shape.push_back(b.create<DimOp>(loc, lhs, 1));
  if (result_type.isDynamicDim(2))
    dyn_shape.push_back(b.create<DimOp>(loc, rhs, 2));
  return dyn_shape;
}

class DotGeneralOpOnTensorsConversion
    : public OpConversionPattern<mhlo::DotGeneralOp> {
 public:
  using OpConversionPattern<mhlo::DotGeneralOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::DotGeneralOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
      return failure();
    }
    mhlo::DotDimensionNumbers dim_numbers = op.dot_dimension_numbers();
1202     auto lhs_bathcing_dims =
1203         Extract1DVector(dim_numbers.lhs_batching_dimensions());
1204     auto rhs_bathcing_dims =
1205         Extract1DVector(dim_numbers.rhs_batching_dimensions());
1206     auto lhs_contracting_dims =
1207         Extract1DVector(dim_numbers.lhs_contracting_dimensions());
1208     auto rhs_contracting_dims =
1209         Extract1DVector(dim_numbers.rhs_contracting_dimensions());
1210     if (lhs_bathcing_dims.size() != 1 || lhs_bathcing_dims[0] != 0) {
1211       return rewriter.notifyMatchFailure(
1212           op, "expected lhs batching dimensions exactly {0}");
1213     }
1214     if (rhs_bathcing_dims.size() != 1 || rhs_bathcing_dims[0] != 0) {
1215       return rewriter.notifyMatchFailure(
1216           op, "expected rhs batching dimensions exactly {0}");
1217     }
1218     if (lhs_contracting_dims.size() != 1 || lhs_contracting_dims[0] != 2) {
1219       return rewriter.notifyMatchFailure(
1220           op, "expected lhs contracting dimensions exactly {2}");
1221     }
1222     if (rhs_contracting_dims.size() != 1 || rhs_contracting_dims[0] != 1) {
1223       return rewriter.notifyMatchFailure(
1224           op, "expected rhs contracting dimensions exactly {1}");
1225     }
1226     Location loc = op.getLoc();
1227     mhlo::DotGeneralOp::Adaptor adaptor(args);
1228     Type result_type = op.getResult().getType();
1229     auto shaped_type = result_type.cast<ShapedType>();
1230     SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
1231         rewriter, loc, adaptor.lhs(), adaptor.rhs(), shaped_type);
1232     auto zero_attr = rewriter.getZeroAttr(shaped_type.getElementType());
1233     Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
1234     auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape);
1235     Value zero_tensor =
1236         rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
1237     auto linalg_op = rewriter.create<linalg::BatchMatmulOp>(
1238         loc, /*resultTensorTypes=*/TypeRange{result_type},
1239         /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
1240         /*outputBuffers=*/ValueRange{zero_tensor});
1241     rewriter.replaceOp(op, linalg_op.getResults());
1242     return success();
1243   }
1244 };
1245 
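// Converts a scalar-typed mhlo op inside a reduction body to the equivalent
// standard-dialect op via HloOpToStdScalarOp. For example (an illustrative
// sketch, not output captured from this pattern), once the enclosing region's
// block arguments have been rewritten to scalars, an
//
//   %0 = mhlo.add %lhs, %rhs : tensor<f32>
//
// inside a reduce body becomes
//
//   %0 = addf %lhs, %rhs : f32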
template <typename OpTy>
struct ReduceRegionXLAOpConversion : public OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      OpTy op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    // Only convert the body of reduction ops to std ops.
    auto parent_op = op.getOperation()->getParentRegion()->getParentOp();
    if (!isa<mhlo::ReduceOp, linalg::GenericOp, linalg::IndexedGenericOp>(
            parent_op)) {
      return failure();
    }
    if (!op.getResult().getType().template isa<TensorType>()) return failure();
    // Wait until the block arguments have been converted to scalars: if every
    // operand is still a tensor, the surrounding region has not been rewritten
    // yet.
    if (llvm::all_of(args, [](Value arg) {
          return arg.getType().template isa<TensorType>();
        })) {
      return failure();
    }
    Value result = lmhlo::HloOpToStdScalarOp::map<OpTy>(op, args[0].getType(),
                                                        args, &rewriter);
    rewriter.replaceOp(op, result);
    return success();
  }
};

SmallVector<Value, 8> GetReduceOpInitTensorDynSizes(
    OpBuilder& b, Location loc, Value arg, ShapedType result_type,
    ArrayRef<int64_t> reduction_dims) {
  llvm::SmallSetVector<int, 4> s;
  for (auto dim : reduction_dims) s.insert(dim);

  // Walk the parallel (non-reduced) dimensions of `arg`; for each one that is
  // dynamic in the result type, materialize its size.
  SmallVector<Value, 8> dyn_shape;
  int rank = arg.getType().cast<RankedTensorType>().getRank();
  for (int i = 0, j = 0; i < rank; ++i) {
    if (s.count(i)) continue;
    if (!result_type.isDynamicDim(j++)) continue;
    dyn_shape.push_back(b.create<DimOp>(loc, arg, i));
  }

  return dyn_shape;
}

class ReduceRegionReturnOpConversion
    : public OpConversionPattern<mhlo::ReturnOp> {
 public:
  using OpConversionPattern<mhlo::ReturnOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::ReturnOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    rewriter.replaceOpWithNewOp<linalg::YieldOp>(op, args);
    return success();
  }
};

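// Converts mhlo.reduce on tensors to a linalg.generic whose source indexing
// map transposes the reduction dimensions into the innermost loop positions.
// As an illustrative sketch only (the maps shown are hand-derived for a sum
// over dimension 0 of a 2-D input; SSA names are assumptions), a reduction
// like
//
// %res = "mhlo.reduce"(%arg, %init) ({
//   ^bb0(%a: tensor<f32>, %b: tensor<f32>):
//     %s = mhlo.add %a, %b : tensor<f32>
//     "mhlo.return"(%s) : (tensor<f32>) -> ()
// }) {dimensions = dense<0> : tensor<1xi64>}
//     : (tensor<5x4xf32>, tensor<f32>) -> tensor<4xf32>
//
// is expected to become a linalg.generic with
//   indexing_maps = [(d0, d1) -> (d1, d0), (d0, d1) -> (d0)]
//   iterator_types = ["parallel", "reduction"]
// over an init tensor filled with the value of %init.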
class ReduceOnTensorsConversion : public OpConversionPattern<mhlo::ReduceOp> {
 public:
  using OpConversionPattern<mhlo::ReduceOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::ReduceOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    Location loc = op.getLoc();
    mhlo::ReduceOp::Adaptor adaptor(args);
    if (op.getNumOperands() != 2) {
      return op.emitError("expects exactly two operands");
    }
    Value src = adaptor.operands()[0];
    auto src_type = src.getType().cast<ShapedType>();
    int src_rank = src_type.getRank();
    if (!src_rank) {
      return rewriter.notifyMatchFailure(op, "expects rank >= 1 args");
    }

    // Check if init_value is constant. If so, inline the value into the
    // region.
    Value init_value = adaptor.init_values()[0];
    Attribute init_const_val = GetInitValueAsConst(init_value);
    if (init_const_val) {
      init_value = rewriter.create<ConstantOp>(
          init_value.getDefiningOp()->getLoc(), init_const_val);
    } else {
      init_value = rewriter.create<tensor::ExtractOp>(loc, init_value);
    }

    // Prepare indexing maps for the linalg.generic op. The elements are for
    // src and dst. Transpose `src` so that the reduction loops become the
    // innermost loops, which makes it easier to fully utilize the processor.
    SmallVector<AffineMap, 3> indexing_maps;
    SmallVector<int64_t, 4> reduction_dims = Extract1DVector(op.dimensions());
    indexing_maps.emplace_back(GetTransposeMapForReduction(
        rewriter.getContext(), src_rank, reduction_dims));

    // The indexing map of `dst` should drop the reduction loops. Since the
    // reduction loops are now the innermost ones, drop the last
    // `reduction_dims.size()` dimensions. No inverse permutation is needed
    // because the remaining dimensions keep their relative order.
    SmallVector<AffineExpr, 4> exprs;
    for (int i = 0, e = src_rank - reduction_dims.size(); i < e; ++i)
      exprs.push_back(rewriter.getAffineDimExpr(i));
    indexing_maps.emplace_back(AffineMap::get(src_rank, /*symbolCount=*/0,
                                              exprs, rewriter.getContext()));

    SmallVector<Value, 2> inputs = {adaptor.operands()[0]};
    Type result_type = op.getResult(0).getType();
    auto shaped_type = result_type.cast<ShapedType>();
    SmallVector<Value, 8> dyn_shape = GetReduceOpInitTensorDynSizes(
        rewriter, loc, adaptor.operands()[0], shaped_type, reduction_dims);
    auto init_tensor = GetInitTensor(rewriter, loc, shaped_type, dyn_shape);
    Value filled_tensor =
        rewriter.create<linalg::FillOp>(loc, init_tensor, init_value)
            .getResult(0);

    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/op.getResultTypes(), inputs,
        /*outputBuffers=*/ValueRange{filled_tensor}, indexing_maps,
        GetParallelAndReductionIterators(src_rank, reduction_dims.size()));

    // Convert the signature of the body. The reduce op region apply function
    // has a signature (lhs, rhs) -> output, all of the same tensor type t.
    // This is converted to a function with the same argument list but with
    // element types. E.g., "(tensor<f32>, tensor<f32>) -> tensor<f32>" will
    // be converted to "(f32, f32) -> f32".
    Region& region = linalg_op.region();
    rewriter.inlineRegionBefore(op.body(), region, region.end());
    TypeConverter::SignatureConversion signatureConverter(2);
    signatureConverter.addInputs(0, src_type.getElementType());
    signatureConverter.addInputs(1, src_type.getElementType());
    rewriter.applySignatureConversion(&region, signatureConverter);
    rewriter.replaceOp(op, linalg_op.getResults());
    return success();
  }
};

void populateLHLOToLinalgConversionPattern(MLIRContext* context,
                                           OwningRewritePatternList* patterns) {
  // clang-format off
  patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
                   ConstConverter<lmhlo::ConstOp>,
                   ConvToLinalgConverter,
                   IotaConverter<lmhlo::IotaOp>,
                   LhloBroadcastInDimConverter,
                   PointwiseToLinalgConverter<lmhlo::AbsOp>,
                   PointwiseToLinalgConverter<lmhlo::AddOp>,
                   PointwiseToLinalgConverter<lmhlo::AndOp>,
                   PointwiseToLinalgConverter<lmhlo::Atan2Op>,
                   PointwiseToLinalgConverter<lmhlo::CeilOp>,
                   PointwiseToLinalgConverter<lmhlo::ClampOp>,
                   PointwiseToLinalgConverter<lmhlo::CompareOp>,
                   PointwiseToLinalgConverter<lmhlo::ComplexOp>,
                   PointwiseToLinalgConverter<lmhlo::ConvertOp>,
                   // TODO(ataei): Remove this pattern, CopyOp is folded away.
                   PointwiseToLinalgConverter<lmhlo::CopyOp>,
                   PointwiseToLinalgConverter<lmhlo::CosOp>,
                   PointwiseToLinalgConverter<lmhlo::DivOp>,
                   PointwiseToLinalgConverter<lmhlo::ExpOp>,
                   PointwiseToLinalgConverter<lmhlo::Expm1Op>,
                   PointwiseToLinalgConverter<lmhlo::FloorOp>,
                   PointwiseToLinalgConverter<lmhlo::ImagOp>,
                   PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
                   PointwiseToLinalgConverter<lmhlo::LogOp>,
                   PointwiseToLinalgConverter<lmhlo::LogisticOp>,
                   PointwiseToLinalgConverter<lmhlo::Log1pOp>,
                   PointwiseToLinalgConverter<lmhlo::MaxOp>,
                   PointwiseToLinalgConverter<lmhlo::MinOp>,
                   PointwiseToLinalgConverter<lmhlo::MulOp>,
                   PointwiseToLinalgConverter<lmhlo::NegOp>,
                   PointwiseToLinalgConverter<lmhlo::NotOp>,
                   PointwiseToLinalgConverter<lmhlo::OrOp>,
                   PointwiseToLinalgConverter<lmhlo::PowOp>,
                   PointwiseToLinalgConverter<lmhlo::RealOp>,
                   PointwiseToLinalgConverter<lmhlo::RemOp>,
                   PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
                   PointwiseToLinalgConverter<lmhlo::SelectOp>,
                   PointwiseToLinalgConverter<lmhlo::ShiftLeftOp>,
                   PointwiseToLinalgConverter<lmhlo::ShiftRightArithmeticOp>,
                   PointwiseToLinalgConverter<lmhlo::ShiftRightLogicalOp>,
                   PointwiseToLinalgConverter<lmhlo::SignOp>,
                   PointwiseToLinalgConverter<lmhlo::SinOp>,
                   PointwiseToLinalgConverter<lmhlo::SqrtOp>,
                   PointwiseToLinalgConverter<lmhlo::SubOp>,
                   PointwiseToLinalgConverter<lmhlo::TanhOp>,
                   PointwiseToLinalgConverter<lmhlo::XorOp>,
                   ReduceConverter,
                   ReshapeOpConverter<lmhlo::ReshapeOp>,
                   ReverseConverter<lmhlo::ReverseOp>,
                   ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
                   ScalarPointwiseToStandardConverter<lmhlo::MaxOp>,
                   SliceConverter,
                   TransposeConverter<lmhlo::TransposeOp>
                  >(context);
  // clang-format on
}

// Converts LHLO ops to Linalg generic.
// Sample result for lmhlo::AddOp.
//
// "lmhlo.add"(%arg1, %arg2, %out) :
//      (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
//
// will be converted to
//
// #map0 = (d0, d1) -> (d0, d1)
// "linalg.generic"(%arg1, %arg2, %out) ( {
//   ^bb0(%arg4: f32, %arg5: f32):
//     %0 = addf %arg4, %arg5 : f32
//     "linalg.yield"(%0) : (f32) -> ()
// }) {
//     indexing_maps = [#map0, #map0, #map0],
//     iterator_types = ["parallel", "parallel"],
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
struct LhloLegalizeToLinalgPass
    : public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<AffineDialect, linalg::LinalgDialect, math::MathDialect>();
  }

  void runOnFunction() override {
    OwningRewritePatternList patterns;
    ConversionTarget target(getContext());
    target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
                           math::MathDialect, StandardOpsDialect,
                           AffineDialect>();

    auto func = getFunction();
    populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
    if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

struct HloLegalizeToLinalgPass
    : public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<linalg::LinalgDialect, scf::SCFDialect,
                    complex::ComplexDialect, math::MathDialect>();
  }

  void runOnFunction() override {
    OwningRewritePatternList patterns;
    ConversionTarget target(getContext());
    target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
                           math::MathDialect, StandardOpsDialect,
                           tensor::TensorDialect, scf::SCFDialect>();

    auto func = getFunction();
    mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
    if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

}  // namespace

namespace lmhlo {
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass() {
  return std::make_unique<LhloLegalizeToLinalgPass>();
}
}  // namespace lmhlo

namespace mhlo {

void populateHLOToLinalgConversionPattern(MLIRContext* context,
                                          OwningRewritePatternList* patterns) {
  patterns->insert<
      BroadcastConverter<mhlo::BroadcastOp, false>,
      ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter,
      HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
      PointwiseToLinalgConverter<mhlo::AbsOp, false>,
      PointwiseToLinalgConverter<mhlo::AddOp, false>,
      PointwiseToLinalgConverter<mhlo::AndOp, false>,
      PointwiseToLinalgConverter<mhlo::Atan2Op, false>,
      PointwiseToLinalgConverter<mhlo::CeilOp, false>,
      PointwiseToLinalgConverter<mhlo::ClampOp, false>,
      PointwiseToLinalgConverter<mhlo::CompareOp, false>,
      PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
      PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
      PointwiseToLinalgConverter<mhlo::CopyOp, false>,
      PointwiseToLinalgConverter<mhlo::CosOp, false>,
      PointwiseToLinalgConverter<mhlo::DivOp, false>,
      PointwiseToLinalgConverter<mhlo::ExpOp, false>,
      PointwiseToLinalgConverter<mhlo::Expm1Op, false>,
      PointwiseToLinalgConverter<mhlo::FloorOp, false>,
      PointwiseToLinalgConverter<mhlo::ImagOp, false>,
      PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
      PointwiseToLinalgConverter<mhlo::LogOp, false>,
      PointwiseToLinalgConverter<mhlo::LogisticOp, false>,
      PointwiseToLinalgConverter<mhlo::Log1pOp, false>,
      PointwiseToLinalgConverter<mhlo::MaxOp, false>,
      PointwiseToLinalgConverter<mhlo::MinOp, false>,
      PointwiseToLinalgConverter<mhlo::MulOp, false>,
      PointwiseToLinalgConverter<mhlo::NegOp, false>,
      PointwiseToLinalgConverter<mhlo::NotOp, false>,
      PointwiseToLinalgConverter<mhlo::OrOp, false>,
      PointwiseToLinalgConverter<mhlo::PowOp, false>,
      PointwiseToLinalgConverter<mhlo::RealOp, false>,
      PointwiseToLinalgConverter<mhlo::RemOp, false>,
      PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
      PointwiseToLinalgConverter<mhlo::SelectOp, false>,
      PointwiseToLinalgConverter<mhlo::ShiftLeftOp, false>,
      PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp, false>,
      PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp, false>,
      PointwiseToLinalgConverter<mhlo::SignOp, false>,
      PointwiseToLinalgConverter<mhlo::SinOp, false>,
      PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
      PointwiseToLinalgConverter<mhlo::SubOp, false>,
      PointwiseToLinalgConverter<mhlo::TanhOp, false>,
      PointwiseToLinalgConverter<mhlo::XorOp, false>,
      ReshapeOpConverter<mhlo::ReshapeOp, false>,
      ReverseConverter<mhlo::ReverseOp, false>,
      TransposeConverter<mhlo::TransposeOp, false>, DotOpOnTensorsConversion,
      DotGeneralOpOnTensorsConversion, ReduceOnTensorsConversion>(context);
  patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,
                   ReduceRegionXLAOpConversion<mhlo::MinOp>,
                   ReduceRegionXLAOpConversion<mhlo::MaxOp>,
                   ReduceRegionReturnOpConversion>(context);
}

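// Illustrative use of the factory below (a sketch, not code from this file;
// `module` is an assumed mlir::ModuleOp obtained elsewhere):
//
//   mlir::PassManager pm(module.getContext());
//   pm.addNestedPass<mlir::FuncOp>(
//       mlir::mhlo::createLegalizeHloToLinalgPass());
//   if (failed(pm.run(module))) { /* handle failure */ }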
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
  return std::make_unique<HloLegalizeToLinalgPass>();
}
}  // namespace mhlo
}  // namespace mlir