/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.

#include <numeric>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace {

/// Returns a vector of `nLoops` iterator type strings. All entries are
/// "parallel" except the last `nReduction` ones, which are "reduction"
/// iterators.
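/// For example, GetParallelAndReductionIterators(4, 1) returns
/// {"parallel", "parallel", "parallel", "reduction"}.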
SmallVector<StringRef, 3> GetParallelAndReductionIterators(
    unsigned nLoops, unsigned nReduction) {
  SmallVector<StringRef, 3> res(nLoops - nReduction,
                                getParallelIteratorTypeName());
  res.append(nReduction, getReductionIteratorTypeName());
  return res;
}

SmallVector<StringRef, 3> GetNParallelLoopsAttrs(unsigned nParallelLoops) {
  return GetParallelAndReductionIterators(nParallelLoops, 0);
}

template <bool isLHLO = true>
Value GetResultValue(Operation* op) {
  return isLHLO ? op->getOperand(op->getNumOperands() - 1) : op->getResult(0);
}

template <bool isLHLO = true>
ShapedType GetHloOpResultType(Operation* op) {
  return GetResultValue<isLHLO>(op).getType().template cast<ShapedType>();
}

template <bool isLHLO = true>
bool VerifyHloOpBufferOrTensorSemantics(Operation* op) {
  auto verify_type = [&](Value val) -> bool {
    return (isLHLO && val.getType().isa<MemRefType>()) ||
           (!isLHLO && val.getType().isa<RankedTensorType>());
  };
  if (!llvm::all_of(op->getOperands(), verify_type)) return false;
  return isLHLO ? op->getResults().empty()
                : llvm::all_of(op->getResults(), verify_type);
}

Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type,
                    ArrayRef<Value> dyn_sizes) {
  return b.create<linalg::InitTensorOp>(loc, dyn_sizes, type.getShape(),
                                        type.getElementType());
}

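// Collects the Values holding the sizes of the dynamic dimensions of
// `tensor`, reading them from `shape_tensor` when one is provided. For
// example, for a tensor<?x4x?xf32> value this returns the two size Values
// for dimensions 0 and 2; static dimensions contribute nothing.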
SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc,
                                          Value tensor,
                                          Value shape_tensor = nullptr,
                                          AffineMap permutation = {}) {
  auto tensor_type = tensor.getType().dyn_cast<RankedTensorType>();
  if (!tensor_type) return {};
  SmallVector<Value, 2> dyn_sizes(tensor_type.getRank());
  for (auto& en : llvm::enumerate(tensor_type.getShape())) {
    if (en.value() != ShapedType::kDynamicSize) continue;
    // If a shape tensor is present, extract the size from there.
    if (shape_tensor) {
      Value extract = b.create<tensor::ExtractOp>(
          loc, shape_tensor,
          ValueRange{b.create<ConstantIndexOp>(loc, en.index())});
      dyn_sizes[en.index()] =
          b.create<IndexCastOp>(loc, b.getIndexType(), extract);
    } else {
      dyn_sizes[en.index()] = b.create<tensor::DimOp>(loc, tensor, en.index());
    }
  }
  if (permutation)
    dyn_sizes = applyPermutationMap(permutation, makeArrayRef(dyn_sizes));
  llvm::erase_value(dyn_sizes, nullptr);  // Strip out placeholders.
  return dyn_sizes;
}

SmallVector<int64_t, 4> Extract1DVector(DenseIntElementsAttr elements) {
  SmallVector<int64_t, 4> ret;
  for (const APInt& element : elements) {
    ret.push_back(element.getLimitedValue());
  }
  return ret;
}

/// Returns the constant value associated with the init value if the defining
/// operation is a constant.
Attribute GetInitValueAsConst(Value init) {
  DenseElementsAttr attr;
  if (!matchPattern(init, m_Constant(&attr))) return {};
  auto type = attr.getType().dyn_cast<ShapedType>();
  if (!type || type.getRank() != 0) return {};
  return attr.getValue({});
}

/// Returns a permutation AffineMap that puts all reduction dimensions last.
/// Parallel dimensions keep their relative order and reduction dimensions are
/// appended in their given order. E.g., if `rank` is 4 and `reductionDims` is
/// {1, 3}, then "(d0, d1, d2, d3) -> (d0, d2, d1, d3)" is used. The inverse
/// permutation of the AffineMap is returned.
AffineMap GetTransposeMapForReduction(MLIRContext* context, int rank,
                                      ArrayRef<int64_t> reduction_dims) {
  llvm::SmallSetVector<int, 4> s;
  for (auto dim : reduction_dims) s.insert(dim);

  SmallVector<unsigned, 4> permutation;
  for (int i = 0; i < rank; ++i)
    if (!s.count(i)) permutation.push_back(i);
  for (auto dim : reduction_dims) permutation.push_back(dim);

  auto map = AffineMap::getPermutationMap(permutation, context);
  return inversePermutation(map);
}

/// Returns true if the given `attr` is a splat of the given `value`.
bool isSplatValue(DenseIntElementsAttr attr, uint64_t value) {
  return attr.isSplat() && attr.getSplatValue<uint64_t>() == value;
}

/// Returns true if the given `dimensionNumbers` from a mhlo.convolution op
/// follow a canonical form:
///
/// * Input dimensions have order: (batch_count, spatial_dims,
///   input_channel_count).
/// * Filter dimensions have order: (spatial_dims, input_channel_count,
///   output_channel_count).
/// * Output dimensions have order: (batch_count, spatial_dims,
///   output_channel_count).
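///
/// For a 2-D convolution this is the familiar NHWC / HWIO / NHWC layout:
/// input dims (batch, h, w, feature) = (0, 1, 2, 3), kernel dims
/// (h, w, input_feature, output_feature) = (0, 1, 2, 3), and output dims
/// (batch, h, w, feature) = (0, 1, 2, 3).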
template <typename DimensionNumbersTy>
static bool HasCanonicalDimensionNumbers(
    const DimensionNumbersTy& dimension_numbers) {
  const int input_spatial_rank =
      llvm::size(dimension_numbers.input_spatial_dimensions());
  // The dimensions for input should follow the order of
  // batch_count, spatial_dims..., input_feature_count.
  if (dimension_numbers.input_batch_dimension().getInt() != 0 ||
      dimension_numbers.input_feature_dimension().getInt() !=
          (input_spatial_rank + 1)) {
    return false;
  }

  const int kernel_spatial_rank =
      llvm::size(dimension_numbers.kernel_spatial_dimensions());
  // The dimensions for filter should follow the order of
  // spatial_dims..., input_feature_count, output_feature_count.
  if (dimension_numbers.kernel_input_feature_dimension().getInt() !=
          kernel_spatial_rank ||
      dimension_numbers.kernel_output_feature_dimension().getInt() !=
          (kernel_spatial_rank + 1)) {
    return false;
  }

  const int output_spatial_rank =
      llvm::size(dimension_numbers.output_spatial_dimensions());
  // The dimensions for output should follow the order of
  // batch_count, spatial_dims..., output_feature_count.
  if (dimension_numbers.output_batch_dimension().getInt() != 0 ||
      dimension_numbers.output_feature_dimension().getInt() !=
          (output_spatial_rank + 1)) {
    return false;
  }

  if (input_spatial_rank != output_spatial_rank ||
      input_spatial_rank != kernel_spatial_rank) {
    return false;
  }

  auto input_spatial_dim = dimension_numbers.input_spatial_dimensions().begin();
  auto kernel_spatial_dim =
      dimension_numbers.kernel_spatial_dimensions().begin();
  auto output_spatial_dim =
      dimension_numbers.output_spatial_dimensions().begin();
  // Check that the spatial dims are ordered correctly.
  for (int i = 0; i < input_spatial_rank; ++i) {
    const int dim = i + 1;
    if ((*input_spatial_dim++).getZExtValue() != dim ||
        (*output_spatial_dim++).getZExtValue() != dim ||
        (*kernel_spatial_dim++).getZExtValue() != i) {
      return false;
    }
  }

  return true;
}

//===----------------------------------------------------------------------===//
// mhlo.Einsum conversion patterns.
//===----------------------------------------------------------------------===//

// Labels each input dimension by looking it up in the set of dimensions
// marked as reduction axes: if the dimension is found in the set it becomes a
// "reduction" iterator, otherwise a "parallel" one.
SmallVector<StringRef, 3> GetEinsumLoopsAttrs(
    const llvm::SmallSetVector<StringRef, 4>& input_ind,
    const llvm::SmallSetVector<StringRef, 4>& reduction_dims) {
  SmallVector<StringRef, 3> res;
  for (StringRef dim : input_ind) {
    if (!reduction_dims.contains(dim)) {
      res.push_back(getParallelIteratorTypeName());
    } else {
      res.push_back(getReductionIteratorTypeName());
    }
  }
  return res;
}

SmallVector<Value, 2> ExtractDynamicEinsumSizes(
    OpBuilder& b, Location loc, Value lhs, Value rhs,
    const SmallVector<std::string>& lhs_loop_vec,
    const SmallVector<std::string>& rhs_loop_vec,
    const SmallVector<std::string>& output_loop_vec) {
  SmallVector<Value, 2> dyn_sizes;
  for (const std::string& dim_ind : output_loop_vec) {
    Value dim_size;
    auto dim_ind_it =
        std::find(lhs_loop_vec.begin(), lhs_loop_vec.end(), dim_ind);
    if (dim_ind_it != lhs_loop_vec.end()) {
      // Query from lhs vars.
      auto dim_ind_pos = dim_ind_it - lhs_loop_vec.begin();
      auto lhs_shape = lhs.getType().dyn_cast<RankedTensorType>().getShape();
      if (lhs_shape[dim_ind_pos] != ShapedType::kDynamicSize) continue;
      dim_size = b.create<tensor::DimOp>(loc, lhs, dim_ind_pos);
    } else {
      // Query from rhs vars.
      dim_ind_it = std::find(rhs_loop_vec.begin(), rhs_loop_vec.end(), dim_ind);
      auto dim_ind_pos = dim_ind_it - rhs_loop_vec.begin();
      auto rhs_shape = rhs.getType().dyn_cast<RankedTensorType>().getShape();
      if (rhs_shape[dim_ind_pos] != ShapedType::kDynamicSize) continue;
      dim_size = b.create<tensor::DimOp>(loc, rhs, dim_ind_pos);
    }
    dyn_sizes.push_back(dim_size);
  }
  return dyn_sizes;
}

// Returns the indices/axes that appear in the input set but are missing from
// the output set; these are the axes to sum over.
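// For example, for einsum config "abc,cd->abd" the input set is {a, b, c, d}
// and the output set is {a, b, d}, so the summation axes are {c}.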
llvm::SmallSetVector<StringRef, 4> FindSummationAxes(
    const llvm::SmallSetVector<StringRef, 4>& input_set,
    const llvm::SmallSetVector<StringRef, 4>& output_set) {
  llvm::SmallSetVector<StringRef, 4> summation_axes;
  for (StringRef ind : input_set) {
    if (!output_set.contains(ind)) summation_axes.insert(ind);
  }
  return summation_axes;
}

// Given a 1:1 map from std::string -> affine dimension expression, returns
// the affine expressions of the dimensions that an operand will access,
// based on the operand's substring of the einsum_config.
// For example:
// let string_dim_umap = {'a' : d0, 'b' : d1, 'c' : d2}
// for einsum_config "abc,cb->acb"
// first_input_operand will get umap[{"a","b","c"}] -> (d0, d1, d2).
// second_input_operand will get umap[{"c","b"}] -> (d2, d1).
// output_operand will get umap[{"a","c","b"}] -> (d0, d2, d1).
SmallVector<AffineExpr> GetExprFromConfig(
    const SmallVector<std::string>& loop_dims,
    const DenseMap<StringRef, AffineExpr>& str_affine_dim_umap) {
  SmallVector<AffineExpr> exprs;
  for (const auto& dim : loop_dims) {
    exprs.push_back(str_affine_dim_umap.lookup(dim));
  }
  return exprs;
}

// Converts mhlo.einsum op into linalg.generic.
// The algorithm proceeds in three steps:
//
// Step 1) Dissect the entire einsum_config into its operands,
// e.g. f("abc,cd->abd") = {lhs:["abc"], rhs:["cd"], out:["abd"]}.
//
// Step 2) Split each string into a vector of its elements,
// e.g. {lhs:["abc"], rhs:["cd"], out:["abd"]} = {lhs:["a","b","c"],
// rhs:["c","d"], out:["a","b","d"]}.
//
// Step 3) Convert the vectors into the data access
// pattern represented by affineMaps with affineDimensions, e.g.
// {lhs:["a","b","c"], rhs:["c","d"], out:["a","b","d"]} = {lhs:[d0,d1,d2],
// rhs:[d2,d3], out:[d0,d1,d3]}.
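//
// As a sketch (IR shown schematically, not verbatim), an einsum with config
// "abc,cd->abd" then lowers to a linalg.generic with indexing maps
//   lhs: (d0, d1, d2, d3) -> (d0, d1, d2)
//   rhs: (d0, d1, d2, d3) -> (d2, d3)
//   out: (d0, d1, d2, d3) -> (d0, d1, d3)
// iterator types ["parallel", "parallel", "reduction", "parallel"], and a
// region that multiplies the two inputs and accumulates into a zero-filled
// init tensor.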
class EinsumToLinalgConverter : public OpConversionPattern<mhlo::EinsumOp> {
 public:
  using OpConversionPattern<mhlo::EinsumOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::EinsumOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    mhlo::EinsumOp::Adaptor adaptor(args);
    auto get_rank = [](Value v) {
      return v.getType().cast<ShapedType>().getRank();
    };
    auto einsum_config = op.einsum_config();

    // Assuming two input operands and a single output, get the operands'
    // index strings.
    // einsum_config = "lhs_loop,rhs_loop->out_loop"
    std::size_t pos_arrow = einsum_config.find(kArrow);
    std::size_t pos_comma = einsum_config.find(kComma);

    StringRef lhs_loop = einsum_config.substr(0, pos_comma);
    StringRef rhs_loop = einsum_config.substr(
        pos_comma + kComma.size(), pos_arrow - (pos_comma + kComma.size()));
    StringRef out_loop = einsum_config.substr(pos_arrow + kArrow.size());

    // Check for invalid configs:
    // 1. There are at most two inputs.
    // 2. There is at most one output.
    // 3. There is exactly one arrow.
    if (rhs_loop.find(kComma) != std::string::npos ||
        out_loop.find(kComma) != std::string::npos ||
        out_loop.find(kArrow) != std::string::npos) {
      return rewriter.notifyMatchFailure(op, "Invalid einsum config!");
    }

    // Find result type, if on tensors.
    auto result_ty = this->typeConverter
                         ->convertType(GetHloOpResultType</*isLHLO=*/false>(op))
                         .dyn_cast<RankedTensorType>();

    // Check result type compatibility.
    if (!result_ty || !(result_ty.getElementType().isSignlessIntOrFloat())) {
      return rewriter.notifyMatchFailure(op, "Invalid result type");
    }

    // Convert the representation to vector<string>.
    SmallVector<std::string> lhs_ein =
        GetEinsumConfigAsVector(lhs_loop, get_rank(adaptor.lhs()));
    SmallVector<std::string> rhs_ein =
        GetEinsumConfigAsVector(rhs_loop, get_rank(adaptor.rhs()));
    SmallVector<std::string> out_ein =
        GetEinsumConfigAsVector(out_loop, result_ty.getRank());

    if (!CheckBatchHasEqualRank(lhs_ein.size(), lhs_loop, rhs_ein.size(),
                                rhs_loop, out_ein.size(), out_loop)) {
      return rewriter.notifyMatchFailure(
          op, "Invalid ellipsis('...') within einsum config!");
    }

    // Find all unique indices in the input and output.
    llvm::SmallSetVector<StringRef, 4> input_ind;
    llvm::SmallSetVector<StringRef, 4> output_ind;

    input_ind.insert(lhs_ein.begin(), lhs_ein.end());
    input_ind.insert(rhs_ein.begin(), rhs_ein.end());
    output_ind.insert(out_ein.begin(), out_ein.end());

    llvm::SmallSetVector<StringRef, 4> reduction_axe =
        FindSummationAxes(input_ind, output_ind);

    // Find input/output values and types.
    auto loc = op.getLoc();

    // Prepare the init tensor for the linalg.generic op.
    auto dyn_sizes = ExtractDynamicEinsumSizes(
        rewriter, loc, adaptor.lhs(), adaptor.rhs(), lhs_ein, rhs_ein, out_ein);
    Value output = GetInitTensor(rewriter, loc, result_ty, dyn_sizes);
    if (!reduction_axe.empty()) {
      auto zero_attr = rewriter.getZeroAttr(result_ty.getElementType());
      Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
      output = rewriter.create<linalg::FillOp>(loc, zero, output).getResult(0);
    }

    // Create indexing maps.
    // Create a 1:1 map from f:strDimension -> affineDimension.
    int64_t nloops = input_ind.size();
    DenseMap<StringRef, AffineExpr> str_affine_dim_umap;
    for (auto it : llvm::enumerate(input_ind)) {
      str_affine_dim_umap[it.value()] = rewriter.getAffineDimExpr(it.index());
    }

    // From the einsum_config of each operand in vector<string> form, generate
    // the equivalent vector<AffineExpr>.
    SmallVector<AffineMap, 4> maps;
    for (const SmallVector<std::string>& loop_operand :
         {lhs_ein, rhs_ein, out_ein}) {
      auto exprs = GetExprFromConfig(loop_operand, str_affine_dim_umap);
      maps.push_back(AffineMap::get(nloops, 0, exprs, rewriter.getContext()));
    }

    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, result_ty ? result_ty : TypeRange{}, args, output, maps,
        GetEinsumLoopsAttrs(input_ind, reduction_axe),
        [&](OpBuilder& b, Location nested_loc, ValueRange args) {
          Value result_val =
              b.create<mlir::MulFOp>(nested_loc, args[0], args[1]);
          if (!reduction_axe.empty()) {
            result_val =
                b.create<mlir::AddFOp>(nested_loc, args[2], result_val);
          }
          b.create<linalg::YieldOp>(nested_loc, result_val);
        });
    rewriter.replaceOp(op, linalg_op.getResults());
    return success();
  }

 private:
  static constexpr StringRef kArrow = "->";
  static constexpr StringRef kComma = ",";
  static constexpr StringRef kEllipsis = "...";

  static bool CheckBatchHasEqualRank(size_t lhs_rank, StringRef lhs_loop,
                                     size_t rhs_rank, StringRef rhs_loop,
                                     size_t out_rank, StringRef out_loop);
  static SmallVector<std::string> GetEinsumConfigAsVector(StringRef loop,
                                                          size_t operand_rank);
};

// Definitions of the static constexpr util member variables.
constexpr StringRef EinsumToLinalgConverter::kArrow;
constexpr StringRef EinsumToLinalgConverter::kComma;
constexpr StringRef EinsumToLinalgConverter::kEllipsis;

// Converts the representation from string/vector<char> to vector<string>,
// e.g. ("abc") -> {"a", "b", "c"}. For a config with an ellipsis and batch
// rank 3: loop_dim = f("ab...cde") = {"a","b","0","1","2","c","d","e"}.
SmallVector<std::string> EinsumToLinalgConverter::GetEinsumConfigAsVector(
    StringRef loop, size_t operand_rank) {
  SmallVector<std::string> loop_dim;
  size_t pre_elip = loop.find(kEllipsis);
  bool has_elip = pre_elip != std::string::npos;
  if (!has_elip) pre_elip = loop.size();
  // Add the dimensions up to the ellipsis if present, or to the end otherwise.
  for (int pre_elip_ind = 0; pre_elip_ind < pre_elip; pre_elip_ind++) {
    loop_dim.push_back(loop.substr(pre_elip_ind, 1).str());
  }
  if (!has_elip) return loop_dim;
  // Handle the case where an ellipsis is present:
  size_t non_batch_rank = loop.size() - kEllipsis.size();
  size_t batch_rank = operand_rank - non_batch_rank;
  // Add the batch dimensions ("0", ..., "N-1"), where N is the batch rank,
  // into the loop.
  for (int batch_ind = 0; batch_ind < batch_rank; batch_ind++) {
    loop_dim.push_back(std::to_string(batch_ind));
  }
  // Add the dimensions after the ellipsis into the loop.
  int post_elip = pre_elip + kEllipsis.size();
  for (int post_elip_ind = post_elip; post_elip_ind < loop.size();
       post_elip_ind++) {
    loop_dim.push_back(loop.substr(post_elip_ind, 1).str());
  }
  return loop_dim;
}

// Returns true if all operands' batches have the same rank.
bool EinsumToLinalgConverter::CheckBatchHasEqualRank(
    size_t lhs_rank, StringRef lhs_loop, size_t rhs_rank, StringRef rhs_loop,
    size_t out_rank, StringRef out_loop) {
  SmallVector<int, 3> batch_rank_vec;
  if (lhs_rank != lhs_loop.size()) {
    size_t lhs_batch_rank = lhs_rank - (lhs_loop.size() - kEllipsis.size());
    batch_rank_vec.push_back(lhs_batch_rank);
  }
  if (rhs_rank != rhs_loop.size()) {
    size_t rhs_batch_rank = rhs_rank - (rhs_loop.size() - kEllipsis.size());
    batch_rank_vec.push_back(rhs_batch_rank);
  }
  if (out_rank != out_loop.size()) {
    size_t out_batch_rank = out_rank - (out_loop.size() - kEllipsis.size());
    batch_rank_vec.push_back(out_batch_rank);
  }
  // The condition holds trivially if at most one operand has batch dimensions.
  if (batch_rank_vec.size() < 2) return true;
  return std::equal(batch_rank_vec.begin() + 1, batch_rank_vec.end(),
                    batch_rank_vec.begin());
}

template <typename OpTy, bool isLHLO = true>
class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    // Find the maximum rank / number of loops.
    auto get_rank = [](Value v) {
      return v.getType().cast<ShapedType>().getRank();
    };
    auto is_scalar = [&](Value v) { return get_rank(v) == 0; };
    auto it = llvm::find_if_not(args, is_scalar);
    Value max_rank_arg = it != args.end() ? *it : args.front();
    int64_t nloops = get_rank(max_rank_arg);

    if (isLHLO && nloops == 0) return failure();

    // Apply only if all operands are scalar or have the same rank. Some ops,
    // like `mhlo.select`, support implicit broadcasting of scalars.
    if (!llvm::all_of(args, [&](Value v) {
          int64_t r = get_rank(v);
          return r == 0 || r == nloops;
        })) {
      return rewriter.notifyMatchFailure(
          op, "Operands must be of same rank or scalar.");
    }

    // Find result type, if on tensors.
    Optional<ShapedType> result_ty;
    if (!isLHLO) {
      result_ty = this->typeConverter->convertType(op->getResultTypes().front())
                      .template dyn_cast<ShapedType>();

      // Check result type compatibility.
      if (!result_ty || !result_ty->hasRank() ||
          result_ty->getRank() != nloops ||
          !(result_ty->getElementType().isSignlessIntOrFloat() ||
            result_ty->getElementType().isa<ComplexType>())) {
        return rewriter.notifyMatchFailure(
            op, "mismatched operand/result types or iterator count");
      }
    }

    // Find input/output values and types.
    auto loc = op.getLoc();
    ValueRange inputs = isLHLO ? args.drop_back() : args;
    Value output;
    if (isLHLO) {
      output = args.back();
    } else {
      auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, max_rank_arg);
      output = GetInitTensor(rewriter, loc, *result_ty, dyn_sizes);
    }

    // Create indexing maps.
    AffineMap scalar_map = AffineMap::get(nloops, 0, rewriter.getContext());
    AffineMap id_map = rewriter.getMultiDimIdentityMap(nloops);
    SmallVector<AffineMap, 4> maps;
    for (Value v : inputs) maps.push_back(is_scalar(v) ? scalar_map : id_map);
    maps.push_back(id_map);

    // Build the `linalg.generic` op.
    bool failed = false;
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, result_ty ? *result_ty : TypeRange{}, inputs, output, maps,
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
          // TODO(ravishankarm) : For now use the method in lmhlo namespace.
          // That method needs to be moved out of there.
          Type inner_result_ty = getElementTypeOrSelf(output);
          Value inner_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
              op, inner_result_ty,
              llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
          if (inner_result == nullptr) {
            failed = true;
          } else {
            nested_builder.create<linalg::YieldOp>(loc, inner_result);
          }
        });
    if (failed) return failure();
    rewriter.replaceOp(op, linalg_op->getResults());
    return success();
  }
};
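
// As an illustration (schematic, not verbatim IR), a tensor-level
//   %0 = mhlo.add %arg0, %arg1 : tensor<?x?xf32>
// becomes a linalg.generic with identity indexing maps, all-"parallel"
// iterators, an init tensor sized from the operands' dynamic dims as the
// output, and the matching scalar std op (here addf) in the body. Scalar
// (rank-0) operands get the constant (nloops, 0) -> () map instead.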

template <typename LhloOp>
class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
 public:
  using OpConversionPattern<LhloOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      LhloOp lhlo_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = lhlo_op.getLoc();
    auto arg_type =
        lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>();
    if (!arg_type || !arg_type.getElementType().isSignlessIntOrFloat() ||
        (arg_type.getRank() != 0)) {
      return failure();
    }

    // Create two loads from the input.
    auto lhs = rewriter.create<memref::LoadOp>(loc, lhlo_op.lhs());
    auto rhs = rewriter.create<memref::LoadOp>(loc, lhlo_op.rhs());
    // TODO(ravishankarm) : Move this method out of lmhlo namespace.
    Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
        lhlo_op, arg_type.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
        &rewriter);
    rewriter.create<memref::StoreOp>(loc, op_result, lhlo_op.out());
    rewriter.eraseOp(lhlo_op);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// lmhlo.convolution conversion pattern.
//===----------------------------------------------------------------------===//

/// Converts lmhlo.convolution operation to a linalg.conv op.
struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
 public:
  using OpConversionPattern<lmhlo::ConvOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::ConvOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    if (!HasCanonicalDimensionNumbers(op.dimension_numbers())) return failure();

    // TODO: LHS dilation for deconvolution not supported yet.
    // TODO(jurahul): Window reversal is not supported yet.
    if (op.lhs_dilation() || op.hasWindowReversal()) {
      return failure();
    }

    llvm::SmallVector<Attribute, 4> strides;
    if (auto window_strides = op.window_strides()) {
      auto range = window_strides->getAttributeValues();
      strides.assign(range.begin(), range.end());
    }
    auto strides_arg = ArrayAttr::get(op.getContext(), strides);

    llvm::SmallVector<Attribute, 2> dilation;
    if (auto rhs_dilation = op.rhs_dilation()) {
      auto range = rhs_dilation->getAttributeValues();
      dilation.assign(range.begin(), range.end());
    } else {
      // Default dilation of 1.
      dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1));
    }
    auto dilation_arg = ArrayAttr::get(op.getContext(), dilation);

    // Set padding only if it is non-zero.
    DenseIntElementsAttr padding = op.paddingAttr();
    if (!padding ||
        !llvm::any_of(padding.getValues<APInt>(),
                      [](APInt int_val) { return !int_val.isNullValue(); })) {
      padding = nullptr;
    }

    // The order of input and filter is switched with linalg.conv.
    rewriter.replaceOpWithNewOp<linalg::ConvOp>(
        op, args[1], args[0], args[2], strides_arg, dilation_arg, padding);
    return success();
  }
};
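
// Schematically, lmhlo.convolution(%input, %filter, %output) becomes
// linalg.conv(%filter, %input, %output): args[0]/args[1] above are the
// input/filter buffers, and linalg.conv expects the filter first.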

/// Base class for lowering HLO operations that have one operand and one
/// result, and are semantically equivalent to a copy of the input to the
/// output (like transpose, some reshape, etc.). The derived classes need to
/// provide a method `getIndexingMaps` that returns AffineMaps for the index
/// maps of the input and the output.
template <typename Derived, typename OpTy, bool isLHLO = true>
class DataMovementOpConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    if (!VerifyHloOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
    auto result_type = GetHloOpResultType<isLHLO>(op);
    result_type = this->typeConverter->convertType(result_type)
                      .template cast<ShapedType>();

    SmallVector<AffineMap, 2> indexing_maps =
        Derived::getIndexingMaps(op, &rewriter);
    if (indexing_maps.empty()) return failure();

    auto nloops = result_type.getRank();
    auto loc = op.getLoc();
    AffineMap shape_permutation =
        indexing_maps[0].isPermutation() ? indexing_maps[0] : AffineMap();
    // TODO(pifon): technically, the op itself could have size operands (e.g.
    // broadcast into a dynamic dimension). Handle this case.
    auto dyn_sizes = isLHLO ? SmallVector<Value, 2>()
                            : ExtractDynamicSizes(rewriter, loc, args[0],
                                                  nullptr, shape_permutation);
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc,
        /*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : result_type,
        /*inputs=*/args.front(),
        /*outputBuffers=*/
        isLHLO
            ? ValueRange{args.back()}
            : ValueRange{GetInitTensor(rewriter, loc, result_type, dyn_sizes)},
        indexing_maps, GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
          nested_builder.create<linalg::YieldOp>(loc, *args.begin());
        });
    rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
    return success();
  }
};

/// Pattern to convert BroadcastOp to Linalg ops.
template <typename OpTy, bool isLHLO = true>
class BroadcastConverter
    : public DataMovementOpConverter<BroadcastConverter<OpTy, isLHLO>, OpTy,
                                     isLHLO> {
 public:
  using DataMovementOpConverter<BroadcastConverter, OpTy,
                                isLHLO>::DataMovementOpConverter;

  static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcast_op,
                                                   Builder* b) {
    ShapedType input_type =
        broadcast_op.operand().getType().template cast<ShapedType>();
    unsigned input_rank = input_type.getRank();
    unsigned nloops = GetHloOpResultType<isLHLO>(broadcast_op).getRank();

    // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to
    // the input's dimensions.
    unsigned num_prepended_dims = llvm::size(broadcast_op.broadcast_sizes());
    SmallVector<AffineExpr, 4> input_dim_exprs;
    input_dim_exprs.reserve(input_rank);
    for (unsigned i = 0; i < input_rank; ++i) {
      input_dim_exprs.push_back(b->getAffineDimExpr(num_prepended_dims + i));
    }

    AffineMap input_map;
    MLIRContext* context = b->getContext();
    if (input_dim_exprs.empty()) {
      // The input is a scalar, i.e. this is a scalar broadcast op.
      input_map = AffineMap::get(nloops, /*symbolCount=*/0, context);
    } else {
      input_map =
          AffineMap::get(nloops, /*symbolCount=*/0, input_dim_exprs, context);
    }
    return {input_map, b->getMultiDimIdentityMap(nloops)};
  }
};
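
// For instance, broadcasting a rank-2 operand with two prepended dimensions
// (nloops = 4) yields
//   input map:  (d0, d1, d2, d3) -> (d2, d3)
//   output map: (d0, d1, d2, d3) -> (d0, d1, d2, d3)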

class HloBroadcastInDimConverter
    : public DataMovementOpConverter<HloBroadcastInDimConverter,
                                     mhlo::BroadcastInDimOp, false> {
 public:
  using DataMovementOpConverter<HloBroadcastInDimConverter,
                                mhlo::BroadcastInDimOp,
                                false>::DataMovementOpConverter;

  static SmallVector<AffineMap, 2> getIndexingMaps(
      mhlo::BroadcastInDimOp broadcast_op, Builder* b) {
    auto result_type = GetHloOpResultType<false>(broadcast_op);
    auto operand_type =
        broadcast_op.operand().getType().template cast<ShapedType>();
    unsigned nloops = result_type.getRank();

    // The input is a scalar, i.e. this is a scalar broadcast op.
    if (operand_type.getRank() == 0) {
      return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
              b->getMultiDimIdentityMap(nloops)};
    }

    auto operand_shape = operand_type.getShape();
    SmallVector<AffineExpr, 4> dim_exprs;
    dim_exprs.reserve(nloops);

    if (broadcast_op.broadcast_dimensions()) {
      for (const auto& broadcastDim :
           enumerate(broadcast_op.broadcast_dimensions().getIntValues())) {
        int size = broadcastDim.value().getSExtValue();
        bool expansion_needed = operand_shape[broadcastDim.index()] == 1 &&
                                result_type.getShape()[size] != 1;
        dim_exprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
                                             : b->getAffineDimExpr(size));
      }
    }
    return {
        AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
        b->getMultiDimIdentityMap(nloops)};
  }
};
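
// For example (illustrative), broadcasting a tensor<1x5xf32> operand with
// broadcast_dimensions = [0, 2] into tensor<4x6x5xf32> produces the input
// map (d0, d1, d2) -> (0, d2): operand dimension 0 must expand (1 -> 4), so
// it is pinned to constant 0, while operand dimension 1 maps to d2.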

class HloDynamicBroadcastInDimConverter
    : public OpConversionPattern<mhlo::DynamicBroadcastInDimOp> {
 public:
  using OpConversionPattern<mhlo::DynamicBroadcastInDimOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    // If the input has a static shape we know exactly when the broadcast must
    // expand (the dimension is 1, which also trivially expands to 1) or will
    // never expand (the dimension is not 1). This means we can lower the
    // broadcast just as we would lower a fully static broadcast and go directly
    // to linalg.generic. This also covers the important case of broadcasting a
    // scalar.

    // Ideally the pattern (`mhlo.constant` -> `mhlo.dynamic_broadcast_in_dim`)
    // should be converted to a Tensor-dialect op similar to TF ConstantLikeOp.

    mhlo::DynamicBroadcastInDimOp::Adaptor adaptor(operands);
    Value operand = adaptor.operand();
    auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
    if (!operand_type || !operand_type.hasStaticShape()) return failure();

    Value shape = adaptor.output_dimensions();
    auto shape_type = shape.getType().cast<RankedTensorType>();
    int64_t result_rank = shape_type.getDimSize(0);
    // HLO dimension types can be any integer, as well as index.
    bool convert_to_index =
        shape_type.getElementType() != rewriter.getIndexType();

    auto result_type =
        typeConverter->convertType(op.getType()).dyn_cast<RankedTensorType>();
    if (!result_type) return failure();

    SmallVector<Value, 2> dyn_dims;
    Location loc = op.getLoc();
    for (int i = 0; i < result_rank; ++i) {
      if (!result_type.isDynamicDim(i)) continue;
      Value index = rewriter.create<ConstantIndexOp>(loc, i);
      Value dim = rewriter.create<tensor::ExtractOp>(loc, shape, index);
      if (convert_to_index) {
        dim = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), dim);
      }
      dyn_dims.push_back(dim);
    }

    int64_t nloops = result_type.getRank();
    auto operand_shape = operand_type.getShape();
    SmallVector<AffineExpr, 4> dim_exprs;
    dim_exprs.reserve(nloops);

    if (op.broadcast_dimensions()) {
      for (const auto& broadcast_dim :
           enumerate(op.broadcast_dimensions().getIntValues())) {
        int64_t size = broadcast_dim.value().getSExtValue();
        bool expansion_needed = operand_shape[broadcast_dim.index()] == 1;
        dim_exprs.push_back(expansion_needed ? rewriter.getAffineConstantExpr(0)
                                             : rewriter.getAffineDimExpr(size));
      }
    }

    Value init = rewriter.create<linalg::InitTensorOp>(
        loc, dyn_dims, result_type.getShape(), result_type.getElementType());
    Operation* generic = rewriter.create<linalg::GenericOp>(
        loc, TypeRange{init.getType()}, ValueRange{operand},
        /*outputBuffers=*/ValueRange{init},
        llvm::makeArrayRef(
            {AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0, dim_exprs,
                            rewriter.getContext()),
             rewriter.getMultiDimIdentityMap(nloops)}),
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
          nested_builder.create<linalg::YieldOp>(loc, *args.begin());
        });
    rewriter.replaceOp(op, generic->getResults());
    return success();
  }
};

class LhloBroadcastInDimConverter
    : public OpConversionPattern<lmhlo::BroadcastInDimOp> {
 public:
  using OpConversionPattern<lmhlo::BroadcastInDimOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::BroadcastInDimOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
    auto result_type = operand_adaptor.output().getType().cast<MemRefType>();
    auto result_shape = result_type.getShape();

    auto operand_and_dims = InsertReshapeIfNecessary(op, args, rewriter);

    Value operand = std::get<0>(operand_and_dims);
    auto broadcast_dims = std::get<1>(operand_and_dims);

    auto loc = op.getLoc();
    auto nloops = result_type.getRank();
    auto operand_type = operand.getType().cast<MemRefType>();

    // For a degenerate case, i.e. broadcasting with expansion of
    // memref<1xELEMENT_TYPE>, the operand is not passed to `linalg.generic`.
    // Instead the value is loaded and used directly in `linalg.yield`.
    if (operand_type.getRank() == 1 &&
        operand_type.getDimSize(0) <
            result_type.getDimSize(broadcast_dims.front())) {
      Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
      Value val = rewriter.create<memref::LoadOp>(loc, operand,
                                                  llvm::makeArrayRef({zero}));
      rewriter.create<linalg::GenericOp>(
          loc, /*inputs=*/ValueRange{},
          /*outputBuffers=*/ValueRange{operand_adaptor.output()},
          llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
          GetNParallelLoopsAttrs(nloops),
          [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
            nested_builder.create<linalg::YieldOp>(loc, val);
          });

    } else {
      auto indexing_maps = getIndexingMaps(op, broadcast_dims, result_shape,
                                           operand_type, &rewriter);
      rewriter.create<linalg::GenericOp>(
          loc, /*inputs=*/ValueRange{operand},
          /*outputBuffers=*/ValueRange{operand_adaptor.output()}, indexing_maps,
          GetNParallelLoopsAttrs(nloops),
          [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
            nested_builder.create<linalg::YieldOp>(loc, *args.begin());
          });
    }
    rewriter.replaceOp(op, llvm::None);
    return success();
  }

  // Inserts a 'memref.collapse_shape' if there is a size-1 dim expansion.
  std::pair<Value, SmallVector<int64_t, 2>> InsertReshapeIfNecessary(
      lmhlo::BroadcastInDimOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const {
    lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
    Value operand = operand_adaptor.operand();
    auto operand_type = operand_adaptor.operand().getType().cast<MemRefType>();
    auto operand_shape = operand_type.getShape();

    Value result = operand_adaptor.output();
    auto result_type = result.getType().cast<MemRefType>();
    auto result_shape = result_type.getShape();

    SmallVector<int64_t, 2> operand_strides;
    int64_t operand_offset;
    if (failed(getStridesAndOffset(operand_type, operand_strides,
                                   operand_offset))) {
      op.emitOpError() << "Failed to get offset and strides.";
    }

    SmallVector<int64_t, 2> new_shape, new_strides, broadcast_dims;
    SmallVector<ReassociationIndices, 4> collapsed_dims_list;
    ReassociationIndices collapsed_dims;
    for (const auto& item :
         enumerate(op.broadcast_dimensions().getIntValues())) {
      size_t index = item.index();
      int dim = item.value().getSExtValue();

      collapsed_dims.push_back(index);

      bool expansion_needed =
          operand_shape[index] == 1 && result_shape[dim] != 1;
      if (expansion_needed) {
        continue;
      }
      new_shape.push_back(operand_shape[index]);
      new_strides.push_back(operand_strides[index]);
      broadcast_dims.push_back(dim);

      collapsed_dims_list.push_back(collapsed_dims);
      collapsed_dims.clear();
    }
    // If `collapsed_dims_list` is empty, then the memref has shape [1, ..., 1]
    // and all dimensions need expansion. Such a memref will be reshaped to a
    // 1D memref with a single element. The new shape and strides need to be
    // updated accordingly.
    if (collapsed_dims_list.empty()) {
      collapsed_dims_list.push_back({});
      new_shape.push_back(1);
      new_strides.push_back(1);
      broadcast_dims.push_back(0);
    }
    for (const auto& dims : collapsed_dims) {
      collapsed_dims_list.back().push_back(dims);
    }

    // A `memref.collapse_shape` is inserted only if necessary, i.e. when the
    // rank can be reduced.
    if (new_shape.size() < operand_shape.size()) {
      auto new_memref_type = MemRefType::get(
          new_shape, operand_type.getElementType(),
          makeStridedLinearLayoutMap(new_strides, operand_offset,
                                     rewriter.getContext()));
      operand = rewriter.create<memref::CollapseShapeOp>(
          op.getLoc(), new_memref_type, operand_adaptor.operand(),
          collapsed_dims_list);
    }
    return std::make_pair(operand, broadcast_dims);
  }

  SmallVector<AffineMap, 2> getIndexingMaps(lmhlo::BroadcastInDimOp op,
                                            ArrayRef<int64_t> broadcast_dims,
                                            ArrayRef<int64_t> result_shape,
                                            MemRefType operand_type,
                                            Builder* b) const {
    unsigned nloops = result_shape.size();

    // The input is a scalar, i.e. this is a scalar broadcast op.
    if (operand_type.getRank() == 0) {
      return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
              b->getMultiDimIdentityMap(nloops)};
    }

    auto operand_shape = operand_type.getShape();
    SmallVector<AffineExpr, 4> dim_exprs;
    dim_exprs.reserve(nloops);

    for (const auto& broadcast_dim : llvm::enumerate(broadcast_dims)) {
      int size = broadcast_dim.value();
      bool expansion_needed =
          operand_shape[broadcast_dim.index()] == 1 && result_shape[size] != 1;
      if (expansion_needed) {
        op.emitOpError(
            "BroadcastInDimOp lowering to Linalg does not support size-1 "
            "dimensions expansion.");
      }
      dim_exprs.push_back(b->getAffineDimExpr(size));
    }
    return {
        AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
        b->getMultiDimIdentityMap(nloops)};
  }
};

template <typename OpTy, bool isLHLO = true>
class TransposeConverter
    : public DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
                                     isLHLO> {
 public:
  using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
                                isLHLO>::DataMovementOpConverter;
  static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
    auto result_type =
        GetHloOpResultType<isLHLO>(op).template cast<ShapedType>();
    auto nloops = result_type.getRank();
    SmallVector<AffineExpr, 2> input_exprs;
    input_exprs.resize(result_type.getRank());
    for (auto permutation : llvm::enumerate(op.permutation())) {
      input_exprs[permutation.value().getZExtValue()] =
          b->getAffineDimExpr(permutation.index());
    }
    return {
        AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
        b->getMultiDimIdentityMap(nloops)};
  }
};
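
// For example, a transpose with permutation = [1, 0] yields
//   input map:  (d0, d1) -> (d1, d0)
//   output map: (d0, d1) -> (d0, d1)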

// Converts reshape ops that can be proven to be either a collapse of dimensions
// or expansion of dimensions of the operand.
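// For example (illustrative), collapsing tensor<2x3x4xf32> into
// tensor<6x4xf32> uses the reassociation [[0, 1], [2]]; a shape change such
// as tensor<2x3xf32> -> tensor<3x2xf32> cannot be expressed that way and
// instead goes through an intermediate collapse to tensor<6xf32> followed by
// an expansion.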
template <typename OpTy, bool isLHLO = true>
class ReshapeOpConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy reshape_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    if (!VerifyHloOpBufferOrTensorSemantics<isLHLO>(reshape_op))
      return failure();
    typename OpTy::Adaptor operands(args);
    ShapedType operand_type =
        operands.operand().getType().template cast<ShapedType>();
    ShapedType result_type = GetHloOpResultType<isLHLO>(reshape_op);

    if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
      return failure();

    result_type = this->typeConverter->convertType(result_type)
                      .template cast<ShapedType>();

    // Compute the reassociation maps for the linalg operation.
    ArrayRef<int64_t> src_shape =
        (operand_type.getRank() > result_type.getRank()
             ? operand_type.getShape()
             : result_type.getShape());
    ArrayRef<int64_t> dst_shape =
        (operand_type.getRank() > result_type.getRank()
             ? result_type.getShape()
             : operand_type.getShape());
    unsigned curr_src_dim = 0, curr_dst_dim = 0;
    SmallVector<ReassociationExprs, 4> reassociation_map(dst_shape.size());

    // First scan all dimensions in the source shapes to see whether we have a
    // perfect case where consecutive dimensions in the source are collapsed.
    // In such a case we can generate a single collapse/expand op.
    bool is_collapsing_source = true;
    while (curr_src_dim < src_shape.size() && curr_dst_dim < dst_shape.size()) {
      int64_t dst_size = dst_shape[curr_dst_dim];
      int64_t src_size = src_shape[curr_src_dim];
      while (src_size < dst_size && curr_src_dim < src_shape.size()) {
        reassociation_map[curr_dst_dim].push_back(
            rewriter.getAffineDimExpr(curr_src_dim++));
        src_size *= src_shape[curr_src_dim];
      }
      if (src_size == dst_size) {
        reassociation_map[curr_dst_dim].push_back(
            rewriter.getAffineDimExpr(curr_src_dim++));
        // If the next dim in dst_shape is not 1, treat subsequent dims in
        // src_shape which are 1 to be collapsed.
        if (curr_dst_dim == dst_shape.size() - 1 ||
            dst_shape[curr_dst_dim + 1] != 1) {
          while (curr_src_dim < src_shape.size() &&
                 src_shape[curr_src_dim] == 1) {
            reassociation_map[curr_dst_dim].push_back(
                rewriter.getAffineDimExpr(curr_src_dim++));
          }
        }
      } else {
        is_collapsing_source = false;
        break;
      }
      curr_dst_dim++;
    }
    // Rank 0 can always use the direct lowering.
    if (!src_shape.empty() && !dst_shape.empty() &&
        (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size()))
      is_collapsing_source = false;

    // Otherwise, we need to first collapse all source dimensions into one and
    // then expand to the destination dimensions.
    if (!is_collapsing_source) {
      auto get_identity_exprs = [&rewriter](int n) {
        SmallVector<AffineExpr, 4> exprs;
        for (int i = 0; i < n; ++i)
          exprs.push_back(rewriter.getAffineDimExpr(i));
        return exprs;
      };
      Location loc = reshape_op.getLoc();
      int64_t total_elems = std::accumulate(src_shape.begin(), src_shape.end(),
                                            1, std::multiplies<int64_t>());
      auto elem_type = operand_type.getElementType();
      SmallVector<ReassociationExprs, 4> collapsing_map = {
          // Use operand_type here because we need to collapse all operand
          // dimensions.
          get_identity_exprs(operand_type.getShape().size())};
      SmallVector<ReassociationExprs, 4> expanding_map = {
          // Use result_type here because we need to expand to all result
          // dimensions.
          get_identity_exprs(result_type.getShape().size())};

      if (isLHLO) {
        auto collapsed_type = MemRefType::get({total_elems}, elem_type);
        Value collapsed_op = rewriter.create<memref::CollapseShapeOp>(
            loc, collapsed_type, args[0], collapsing_map);
        Value reshape_buffer = rewriter.create<memref::ExpandShapeOp>(
            loc, result_type, collapsed_op, expanding_map);
        rewriter.replaceOpWithNewOp<linalg::CopyOp>(reshape_op, reshape_buffer,
                                                    args[1]);
      } else {
        auto collapsed_type = RankedTensorType::get({total_elems}, elem_type);
        Value collapsed_op = rewriter.create<linalg::TensorCollapseShapeOp>(
            loc, collapsed_type, args[0], collapsing_map);
        rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
            reshape_op, result_type, collapsed_op, expanding_map);
      }
      return success();
    }

    bool isCollapsing =
        result_type.getRank() < args[0].getType().cast<ShapedType>().getRank();
    if (isLHLO) {
      Value reshape_buffer = isCollapsing
                                 ? rewriter
                                       .create<memref::CollapseShapeOp>(
                                           reshape_op.getLoc(), result_type,
                                           args[0], reassociation_map)
                                       .getResult()
                                 : rewriter
                                       .create<memref::ExpandShapeOp>(
                                           reshape_op.getLoc(), result_type,
                                           args[0], reassociation_map)
                                       .getResult();
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(reshape_op, reshape_buffer,
                                                  args[1]);
    } else {
      if (isCollapsing) {
        rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
            reshape_op, result_type, args[0], reassociation_map);
      } else {
        rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
            reshape_op, result_type, args[0], reassociation_map);
      }
    }
    return success();
  }
};
1212 
1213 template <typename OpTy, bool isLHLO = true>
1214 class IotaConverter : public OpConversionPattern<OpTy> {
1215  public:
1216   using OpConversionPattern<OpTy>::OpConversionPattern;
1217 
matchAndRewrite(OpTy iota_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const1218   LogicalResult matchAndRewrite(
1219       OpTy iota_op, ArrayRef<Value> args,
1220       ConversionPatternRewriter& rewriter) const final {
1221     ShapedType result_shaped_type = GetHloOpResultType<isLHLO>(iota_op);
1222     if (!result_shaped_type) return failure();
1223     result_shaped_type = this->typeConverter->convertType(result_shaped_type)
1224                              .template dyn_cast<ShapedType>();
1225 
1226     auto result_element_type = result_shaped_type.getElementType();
1227     if (!result_element_type.isSignlessIntOrFloat()) return failure();
1228 
1229     // Construct the indexing maps needed for linalg.generic ops.
1230     unsigned nloops = result_shaped_type.getRank();
1231 
1232     Location loc = iota_op.getLoc();
1233     // If this is a dynamic iota, the first argument will be a shape tensor.
1234     Value shape_tensor = args.size() > (isLHLO ? 1 : 0) ? args[0] : nullptr;
1235     auto dyn_sizes =
1236         isLHLO
1237             ? SmallVector<Value, 2>()
1238             : ExtractDynamicSizes(
1239                   rewriter, loc, GetResultValue<isLHLO>(iota_op), shape_tensor);
1240     auto linalg_op = rewriter.create<linalg::GenericOp>(
1241         loc,
1242         /*resultTensorTypes=*/
1243         isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type},
1244         /*inputs=*/ValueRange{},
1245         /*outputBuffers=*/
1246         isLHLO ? ValueRange{args.back()}
1247                : ValueRange{GetInitTensor(rewriter, loc, result_shaped_type,
1248                                           dyn_sizes)},
1249         llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
1250         GetNParallelLoopsAttrs(nloops),
1251         [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
1252           Value index_op = nested_builder.create<linalg::IndexOp>(
1253               nested_loc, iota_op.iota_dimension());
1254           Value cast_op = nested_builder.create<IndexCastOp>(
1255               nested_loc, index_op,
1256               nested_builder.getIntegerType(
1257                   result_element_type.getIntOrFloatBitWidth()));
1258           if (result_element_type.template isa<FloatType>()) {
1259             cast_op = nested_builder.create<SIToFPOp>(nested_loc, cast_op,
1260                                                       result_element_type);
1261           }
1262           nested_builder.create<linalg::YieldOp>(nested_loc, cast_op);
1263         });
1264     if (isLHLO)
1265       rewriter.replaceOp(iota_op, llvm::None);
1266     else
1267       rewriter.replaceOp(iota_op, linalg_op.result_tensors());
1268     return success();
1269   }
1270 };
1271 
1272 /// Converts mhlo.concatenate operation to a linalg.generic op.
1273 struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> {
1274   using OpConversionPattern<mhlo::ConcatenateOp>::OpConversionPattern;
1275 
matchAndRewritemlir::__anonb93b5a210111::ConcatenateConverter1276   LogicalResult matchAndRewrite(
1277       mhlo::ConcatenateOp op, ArrayRef<Value> args,
1278       ConversionPatternRewriter& rewriter) const override {
    // Shortcut the one-operand case, which simplifies the code below.
    if (args.size() == 1) {
      rewriter.replaceOp(op, args[0]);
      return success();
    }

    auto result_type =
        this->typeConverter->convertType(op.getResult().getType())
            .dyn_cast<RankedTensorType>();
    if (!result_type) return failure();

    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    uint64_t dim = op.dimension();
    int64_t rank = result_type.getRank();
    Value zero = b.create<ConstantIndexOp>(0);
    SmallVector<Value, 3> sizes;
    for (int64_t i = 0; i < rank; ++i) {
      sizes.push_back(i == dim ? Value() : b.create<tensor::DimOp>(args[0], i));
    }

    // Calculate the size of the concatenated dimension.
    Value result_dim_size;
    for (auto arg : args) {
      Value size = b.create<tensor::DimOp>(arg, dim);
      result_dim_size =
          result_dim_size ? b.create<AddIOp>(result_dim_size, size) : size;
    }
    sizes[dim] = result_dim_size;

    // Allocate the output tensor with init_tensor.
    SmallVector<Value, 3> dyn_sizes;
    for (int64_t i = 0; i < rank; ++i) {
      if (result_type.isDynamicDim(i)) dyn_sizes.push_back(sizes[i]);
    }
    Value result = b.create<linalg::InitTensorOp>(
        dyn_sizes, result_type.getShape(), result_type.getElementType());

    // Generate a generic op to gather the elements of the concatenate. This is
    // awkward standalone but allows fusion with other generic ops.
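    // Sketch of the generated body for two inputs (pseudo-IR, illustrative):
    //   %idx = <index along dim>
    //   %r = scf.if (%idx < size0) {
    //     yield extract %arg0[..., %idx, ...]
    //   } else {
    //     yield extract %arg1[..., %idx - size0, ...]
    //   }
    //   linalg.yield %r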
    unsigned nloops = result_type.getRank();
    auto linalg_op = b.create<linalg::GenericOp>(
        /*resultTensorTypes=*/result_type,
        /*inputs=*/ValueRange{}, /*outputBuffers=*/result,
        llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nested_builder, Location loc, ValueRange) {
          OpBuilder b = nested_builder;
          Value concat_dim_size = zero;
          Value result;

          SmallVector<Value, 4> extract_indices;
          extract_indices.reserve(nloops);
          for (int i = 0; i < nloops; i++) {
            extract_indices.push_back(b.create<linalg::IndexOp>(loc, i));
          }

          Value index_op = b.create<linalg::IndexOp>(loc, dim);
          for (const Value& arg : args) {
            Value new_concat_dim_size;
            scf::IfOp if_op;
            if (&arg != &args.back()) {
              // Calculate how far along we have iterated along the concatenate
              // dimension. That way we can tell which input to select.
              new_concat_dim_size = b.create<AddIOp>(
                  loc, concat_dim_size, b.create<tensor::DimOp>(loc, arg, dim));
              Value cmp = b.create<CmpIOp>(loc, rewriter.getI1Type(),
                                           CmpIPredicate::ult, index_op,
                                           new_concat_dim_size);
              if_op = b.create<scf::IfOp>(loc, result_type.getElementType(),
                                          cmp, true);
              if (result) {
                b.create<scf::YieldOp>(loc, if_op->getResults()[0]);
              } else {
                result = if_op->getResults()[0];
              }

              b = if_op.getThenBodyBuilder(b.getListener());
            }

            // Now adjust the index for the concatenated dimension to fit into
            // the selected tensor and do an extract at that position.
            extract_indices[dim] =
                b.create<SubIOp>(loc, index_op, concat_dim_size);
            Value extract =
                b.create<tensor::ExtractOp>(loc, arg, extract_indices);
            b.create<scf::YieldOp>(loc, extract);

            if (if_op) {
              b = if_op.getElseBodyBuilder(b.getListener());
              concat_dim_size = new_concat_dim_size;
            }
          }
          nested_builder.create<linalg::YieldOp>(loc, result);
        });
    rewriter.replaceOp(op, linalg_op.result_tensors());
    return success();
  }
};

class ConstConverterBuffer : public OpConversionPattern<lmhlo::ConstOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::ConstOp const_op, ArrayRef<Value> /*args*/,
      ConversionPatternRewriter& rewriter) const final {
    Location loc = const_op.getLoc();
    auto value_attr = const_op.value().cast<DenseElementsAttr>();
    if (value_attr.getType().getRank() != 0) return failure();
    Value std_scalar_const =
        rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
    rewriter.create<mlir::AffineStoreOp>(loc, std_scalar_const,
                                         const_op.getOperand(), llvm::None);
    rewriter.eraseOp(const_op);
    return success();
  }
};

class ConstConverterTensor : public OpConversionPattern<mhlo::ConstOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ConstOp const_op, ArrayRef<Value> /*args*/,
      ConversionPatternRewriter& rewriter) const final {
    auto value_attr = const_op.value().cast<DenseElementsAttr>();
    auto type =
        typeConverter->convertType(const_op.getType()).cast<ShapedType>();
    if (type != const_op.getType()) {
      // Signedness conversion.
      value_attr = value_attr.mapValues(type.getElementType(),
                                        [](const APInt& i) { return i; });
    }
    rewriter.replaceOpWithNewOp<ConstantOp>(const_op, type, value_attr);
    return success();
  }
};

class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
 public:
  using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::ReduceOp reduce_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = reduce_op.getLoc();
    lmhlo::ReduceOp::Adaptor adaptor(args);
    auto operand_shape =
        adaptor.inputs()[0].getType().template dyn_cast<ShapedType>();
    if (!operand_shape || !operand_shape.hasRank()) {
      return rewriter.notifyMatchFailure(reduce_op, "expects known-rank args");
    }

    // First fill the output buffer with the init value.
    Value init_value =
        rewriter.create<memref::LoadOp>(loc, adaptor.init_values()[0]);
    rewriter.create<linalg::FillOp>(loc, init_value, adaptor.out()[0]);

    DenseIntElementsAttr dimensions_attr = reduce_op.dimensions();
    SmallVector<int, 4> reduction_dims;
    for (const auto& dim : dimensions_attr.getIntValues()) {
      reduction_dims.push_back(dim.getSExtValue());
    }

    SmallVector<AffineExpr, 2> src_exprs;
    SmallVector<AffineExpr, 2> dst_exprs;
    SmallVector<StringRef, 4> types;
    for (int i = 0, rank = operand_shape.getRank(); i != rank; ++i) {
      bool is_reduced = llvm::is_contained(reduction_dims, i);
      types.push_back(is_reduced ? getReductionIteratorTypeName()
                                 : getParallelIteratorTypeName());

      src_exprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (!is_reduced) {
        dst_exprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      }
    }

    auto maps = AffineMap::inferFromExprList({src_exprs, dst_exprs});

    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/ArrayRef<Type>{},
        /*inputs=*/adaptor.inputs(), /*outputBuffers=*/adaptor.out(), maps,
        types);
    rewriter.inlineRegionBefore(reduce_op.body(), linalg_op.region(),
                                linalg_op.region().end());
    {
      OpBuilder::InsertionGuard region_guard(rewriter);
      Block* block = linalg_op.getBody();
      rewriter.setInsertionPoint(&block->front());

      // The incoming region is operating on buffers, while linalg.generic
      // expects scalar SSA values. Add some allocs around the original op to
      // make it compatible.
      auto arg_type = block->getArgument(0).getType().cast<MemRefType>();
      Value alloc_a = rewriter.create<memref::AllocaOp>(loc, arg_type);
      Value alloc_b = rewriter.create<memref::AllocaOp>(loc, arg_type);
      Value alloc_res = rewriter.create<memref::AllocaOp>(loc, arg_type);

      // Now turn the existing signature
      //   (memref<X>, memref<X>, memref<X>) -> ()
      // into
      //   (X, X) -> X
      TypeConverter::SignatureConversion signature_converter(3);
      signature_converter.remapInput(0, alloc_a);
      signature_converter.remapInput(1, alloc_b);
      signature_converter.remapInput(2, alloc_res);
      signature_converter.addInputs(
          {arg_type.getElementType(), arg_type.getElementType()});
      Block* entry_block = rewriter.applySignatureConversion(
          &linalg_op.region(), signature_converter);

      // Store the arguments into the newly allocated buffers.
      rewriter.setInsertionPointAfter(alloc_res.getDefiningOp());
      rewriter.create<memref::StoreOp>(loc, entry_block->getArgument(0),
                                       alloc_a);
      rewriter.create<memref::StoreOp>(loc, entry_block->getArgument(1),
                                       alloc_b);
      rewriter.replaceOp(entry_block->getTerminator(), {});

      // Load & yield the result.
      rewriter.setInsertionPointToEnd(entry_block);
      auto load_res = rewriter.create<memref::LoadOp>(loc, alloc_res);
      rewriter.create<linalg::YieldOp>(loc, ValueRange{load_res});
    }

    rewriter.replaceOp(reduce_op, linalg_op.getOperation()->getResults());
    return success();
  }
};

// TODO(b/156787842): Support the lowering for dynamic shapes.
template <typename OpTy, bool isLHLO = true>
class ReverseConverter
    : public DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
                                     isLHLO> {
 public:
  using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
                                isLHLO>::DataMovementOpConverter;
  static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
    auto result_type =
        GetHloOpResultType<isLHLO>(op).template cast<ShapedType>();
    auto nloops = result_type.getRank();
    SmallVector<AffineExpr, 2> input_exprs;
    input_exprs.reserve(nloops);
    for (int i = 0; i < nloops; ++i)
      input_exprs.push_back(b->getAffineDimExpr(i));
    for (auto dim : op.dimensions()) {
      int i = dim.getZExtValue();
      if (result_type.isDynamicDim(i)) return {};
      int n = result_type.getShape()[i];
      input_exprs[i] = b->getAffineConstantExpr(n - 1) - input_exprs[i];
    }
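    // E.g., reversing dim 0 of an illustrative tensor<5x?xf32> produces the
    // input map (d0, d1) -> (4 - d0, d1) alongside the identity output map.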
    return {
        AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
        b->getMultiDimIdentityMap(nloops)};
  }
};

template <typename OpTy, bool isLHLO = true>
class SliceConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy slice_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = slice_op.getLoc();
    auto arg_type = args[0].getType().template dyn_cast<ShapedType>();
    if (!arg_type || !arg_type.hasRank()) {
      return rewriter.notifyMatchFailure(slice_op, "expects known-rank args");
    }

    SmallVector<OpFoldResult, 3> offsets, sizes, strides;
    for (int i = 0, e = arg_type.getRank(); i < e; ++i) {
      auto start = slice_op.start_indices().template getValue<int64_t>(i);
      auto limit = slice_op.limit_indices().template getValue<int64_t>(i);
      auto stride = slice_op.strides().template getValue<int64_t>(i);
      offsets.push_back(rewriter.getI64IntegerAttr(start));
      // Say there are k elements in total; then we have the condition
      //   start + (k - 1) * stride <= limit - 1
      // ->
      //   k <= (limit - 1 - start) / stride + 1
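      // For example, with start = 1, limit = 10, stride = 3 the slice keeps
      // elements {1, 4, 7}: k = (10 - 1 - 1) / 3 + 1 = 3.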
      sizes.push_back(
          rewriter.getI64IntegerAttr((limit - 1 - start) / stride + 1));
      strides.push_back(rewriter.getI64IntegerAttr(stride));
    }
    if (isLHLO) {
      auto linalg_op = rewriter.create<memref::SubViewOp>(loc, args[0], offsets,
                                                          sizes, strides);
      rewriter.create<linalg::CopyOp>(loc, linalg_op, args[1]);
      rewriter.eraseOp(slice_op);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
          slice_op, args[0], offsets, sizes, strides);
    }
    return success();
  }
};

class DynamicSliceConverter : public OpConversionPattern<mhlo::DynamicSliceOp> {
 public:
  using OpConversionPattern<mhlo::DynamicSliceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::DynamicSliceOp dynamic_slice_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = dynamic_slice_op.getLoc();
    mhlo::DynamicSliceOp::Adaptor adaptor(args);
    auto arg_type = adaptor.operand().getType().dyn_cast<ShapedType>();
    if (!arg_type || !arg_type.hasRank()) {
      return rewriter.notifyMatchFailure(dynamic_slice_op,
                                         "require known-rank args");
    }

    auto index_type = rewriter.getIndexType();
    SmallVector<OpFoldResult, 3> start_indices, sizes;
    Value zero = rewriter.create<ConstantOp>(
        loc, rewriter.getZeroAttr(adaptor.start_indices()[0]
                                      .getType()
                                      .cast<RankedTensorType>()
                                      .getElementType()));
    for (auto en : llvm::enumerate(
             llvm::zip(adaptor.start_indices(),
                       dynamic_slice_op.slice_sizes().getValues<int64_t>()))) {
      int64_t size = std::get<1>(en.value());
      sizes.push_back(rewriter.getI64IntegerAttr(size));

      // By mhlo.DynamicSlice definition:
      //   `start_indices[i] = clamp(start_indices[i],
      //       0, operand.dimension_size[i] - slice_sizes[i])`
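      // E.g., for an illustrative dimension of size 10 and slice size 4, a
      // start index of 9 is clamped down to 10 - 4 = 6 so that the slice
      // stays in bounds.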
      Value start_index =
          rewriter.create<tensor::ExtractOp>(loc, std::get<0>(en.value()));
      Value ub = rewriter.createOrFold<tensor::DimOp>(loc, adaptor.operand(),
                                                      en.index());
      // ClampOp lowering does not support index type, so cast it into integer
      // type.
      ub = rewriter.createOrFold<IndexCastOp>(loc, start_index.getType(), ub);
      ub = rewriter.createOrFold<SubIOp>(
          loc, ub,
          rewriter.create<ConstantOp>(
              loc, rewriter.getIntegerAttr(start_index.getType(), size)));
      // TODO(hanchung): This is a workaround to use the method because only
      // the lmhlo version is defined. The implementation in
      // map_lmhlo_to_scalar_op.h requires passing an mhlo op; it will convert
      // it to an lmhlo op and call the lmhlo implementation.
      start_index = lmhlo::HloOpToStdScalarOp::map<lmhlo::ClampOp>(
          loc, start_index.getType(),
          ArrayRef<Type>{start_index.getType(), start_index.getType(),
                         start_index.getType()},
          ArrayRef<Value>{zero, start_index, ub}, &rewriter);
      start_indices.push_back(
          rewriter.create<IndexCastOp>(loc, index_type, start_index)
              .getResult());
    }

    int64_t rank = arg_type.getRank();
    SmallVector<OpFoldResult, 3> strides(rank, rewriter.getI64IntegerAttr(1));

    auto result_type =
        this->typeConverter->convertType(dynamic_slice_op.getType())
            .cast<RankedTensorType>();

    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        dynamic_slice_op, result_type, adaptor.operand(), start_indices, sizes,
        strides);
    return success();
  }
};

class DynamicUpdateSliceConverter
    : public OpConversionPattern<mhlo::DynamicUpdateSliceOp> {
 public:
  using OpConversionPattern<mhlo::DynamicUpdateSliceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::DynamicUpdateSliceOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = op.getLoc();
    mhlo::DynamicUpdateSliceOp::Adaptor adaptor(args);
    auto operand_type =
        adaptor.operand().getType().dyn_cast<RankedTensorType>();
    if (!operand_type || !operand_type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          op, "require static ranked type for operand");
    }

    auto update_type = adaptor.update().getType().dyn_cast<RankedTensorType>();
    if (!update_type || !update_type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          op, "require static ranked type for update");
    }

    // We do not have to clamp sizes because the semantics of `update`
    // guarantee that it is always in bounds. See
    // https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice
    SmallVector<OpFoldResult, 3> sizes;
    for (auto size : update_type.getShape()) {
      sizes.push_back(rewriter.getIndexAttr(size));
    }

    auto index_type = rewriter.getIndexType();
    SmallVector<OpFoldResult, 3> start_indices;
    Type start_index_type = adaptor.start_indices()[0]
                                .getType()
                                .cast<RankedTensorType>()
                                .getElementType();
    Value zero = rewriter.create<ConstantOp>(
        loc, rewriter.getZeroAttr(start_index_type));
    for (auto en : llvm::enumerate(adaptor.start_indices())) {
      // By mhlo.DynamicUpdateSlice definition:
      //   `start_indices[i] = clamp(start_indices[i],
      //       0, operand.dimension_size[i] - update.dimension_size[i])`
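      // E.g., writing a length-3 update into an illustrative size-8 dimension
      // clamps a start index of 7 down to 8 - 3 = 5.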
      Value start_index = rewriter.create<tensor::ExtractOp>(loc, en.value());
      Value ub = rewriter.create<ConstantOp>(
          loc, rewriter.getIntegerAttr(start_index_type,
                                       operand_type.getDimSize(en.index()) -
                                           update_type.getDimSize(en.index())));
      // TODO(hanchung): This is a workaround to use the method because only
      // the lmhlo version is defined. The implementation in
      // map_lmhlo_to_scalar_op.h requires passing an mhlo op; it will convert
      // it to an lmhlo op and call the lmhlo implementation.
      start_index = lmhlo::HloOpToStdScalarOp::map<lmhlo::ClampOp>(
          loc, start_index_type,
          ArrayRef<Type>{start_index_type, start_index_type, start_index_type},
          ArrayRef<Value>{zero, start_index, ub}, &rewriter);
      start_indices.push_back(
          rewriter.create<IndexCastOp>(loc, index_type, start_index)
              .getResult());
    }

    int64_t rank = operand_type.getRank();
    SmallVector<OpFoldResult, 3> strides(rank, rewriter.getI64IntegerAttr(1));
    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        op, adaptor.update(), adaptor.operand(), start_indices, sizes, strides);
    return success();
  }
};

enum class DotOperationType {
  kVectorDot = 0,
  kMatrixVector = 1,
  kMatrixMatrix = 2,
  kUnsupported = 3
};

DotOperationType GetDotOperationType(mhlo::DotOp dot_op) {
  ArrayRef<int64_t> lhs_shape =
      dot_op.lhs().getType().cast<ShapedType>().getShape();
  ArrayRef<int64_t> rhs_shape =
      dot_op.rhs().getType().cast<ShapedType>().getShape();
  auto shape_matches = [](int64_t a, int64_t b) {
    return a == ShapedType::kDynamicSize || b == ShapedType::kDynamicSize ||
           a == b;
  };
  if (lhs_shape.size() == 1 && rhs_shape.size() == 1 &&
      shape_matches(lhs_shape[0], rhs_shape[0])) {
    return DotOperationType::kVectorDot;
  }
  if (lhs_shape.size() == 2 && rhs_shape.size() == 1 &&
      shape_matches(lhs_shape[1], rhs_shape[0])) {
    return DotOperationType::kMatrixVector;
  }
  if (lhs_shape.size() == 2 && rhs_shape.size() == 2 &&
      shape_matches(lhs_shape[1], rhs_shape[0])) {
    return DotOperationType::kMatrixMatrix;
  }
  return DotOperationType::kUnsupported;
}

SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
                                                 Value lhs, Value rhs,
                                                 DotOperationType type) {
  SmallVector<Value, 2> dyn_shape;
  switch (type) {
    case DotOperationType::kMatrixMatrix: {
      if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
        dyn_shape.push_back(b.create<tensor::DimOp>(loc, lhs, 0));
      if (rhs.getType().cast<ShapedType>().isDynamicDim(1))
        dyn_shape.push_back(b.create<tensor::DimOp>(loc, rhs, 1));
      break;
    }
    case DotOperationType::kMatrixVector: {
      if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
        dyn_shape.push_back(b.create<tensor::DimOp>(loc, lhs, 0));
      break;
    }
    case DotOperationType::kVectorDot:
    case DotOperationType::kUnsupported:
    default: {
      break;
    }
  }
  return dyn_shape;
}

template <DotOperationType op_type, typename LinalgOp>
class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
 public:
  using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::DotOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
      return failure();
    }
    if (GetDotOperationType(op) != op_type) return failure();

    mhlo::DotOp::Adaptor adaptor(args);

    Location loc = op.getLoc();
    auto output_type = op.getType().cast<ShapedType>();
    auto output_el_type = output_type.getElementType();
    auto zero_attr = rewriter.getZeroAttr(output_el_type);
    Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
    SmallVector<Value, 2> dyn_shape = GetDotOpInitTensorDynSizes(
        rewriter, loc, adaptor.lhs(), adaptor.rhs(), op_type);
    auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
    Value zero_tensor =
        rewriter.create<linalg::FillOp>(loc, zero, init_tensor).getResult(0);
    rewriter.replaceOpWithNewOp<LinalgOp>(
        op, TypeRange{op.getType()}, ValueRange{adaptor.lhs(), adaptor.rhs()},
        ValueRange{zero_tensor});
    return success();
  }
};

SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
    OpBuilder& b, Location loc, Value lhs, Value rhs, ShapedType result_type) {
  SmallVector<Value, 8> dyn_shape;
  if (result_type.isDynamicDim(0))
    dyn_shape.push_back(b.create<tensor::DimOp>(loc, lhs, 0));
  if (result_type.isDynamicDim(1))
    dyn_shape.push_back(b.create<tensor::DimOp>(loc, lhs, 1));
  if (result_type.isDynamicDim(2))
    dyn_shape.push_back(b.create<tensor::DimOp>(loc, rhs, 2));
  return dyn_shape;
}

class DotGeneralOpOnTensorsConversion
    : public OpConversionPattern<mhlo::DotGeneralOp> {
 public:
  using OpConversionPattern<mhlo::DotGeneralOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::DotGeneralOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
      return failure();
    }

    mhlo::DotDimensionNumbers dim_numbers = op.dot_dimension_numbers();
    auto lhs_batching_dims =
        Extract1DVector(dim_numbers.lhs_batching_dimensions());
    auto rhs_batching_dims =
        Extract1DVector(dim_numbers.rhs_batching_dimensions());
    auto lhs_contracting_dims =
        Extract1DVector(dim_numbers.lhs_contracting_dimensions());
    auto rhs_contracting_dims =
        Extract1DVector(dim_numbers.rhs_contracting_dimensions());
    if (lhs_batching_dims.size() != 1 || lhs_batching_dims[0] != 0) {
      return rewriter.notifyMatchFailure(
          op, "expected lhs batching dimensions exactly {0}");
    }
    if (rhs_batching_dims.size() != 1 || rhs_batching_dims[0] != 0) {
      return rewriter.notifyMatchFailure(
          op, "expected rhs batching dimensions exactly {0}");
    }
    if (lhs_contracting_dims.size() != 1 || lhs_contracting_dims[0] != 2) {
      return rewriter.notifyMatchFailure(
          op, "expected lhs contracting dimensions exactly {2}");
    }
    if (rhs_contracting_dims.size() != 1 || rhs_contracting_dims[0] != 1) {
      return rewriter.notifyMatchFailure(
          op, "expected rhs contracting dimensions exactly {1}");
    }
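    // Together these checks admit exactly the batch-matmul form: lhs is
    // b x m x k and rhs is b x k x n, which lowers to linalg.batch_matmul.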

    mhlo::DotGeneralOp::Adaptor adaptor(args);

    Location loc = op.getLoc();
    auto output_type = op.getType().cast<ShapedType>();
    auto output_el_type = output_type.getElementType();
    SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
        rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type);
    auto zero_attr = rewriter.getZeroAttr(output_el_type);
    Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
    auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
    Value zero_tensor =
        rewriter.create<linalg::FillOp>(loc, zero, init_tensor).getResult(0);
    Operation* linalg_op = rewriter.create<linalg::BatchMatmulOp>(
        loc, /*resultTensorTypes=*/TypeRange{op.getType()},
        /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
        /*outputBuffers=*/ValueRange{zero_tensor});

    rewriter.replaceOp(op, linalg_op->getResults());
    return success();
  }
};

bool IsInBodyOfLinalgOps(Operation* op) {
  auto parent_op = op->getParentRegion()->getParentOp();
  return parent_op->getDialect() ==
         parent_op->getContext()->getLoadedDialect<linalg::LinalgDialect>();
}

template <typename OpTy>
struct ReduceRegionXLAOpConversion : public OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      OpTy op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    if (!IsInBodyOfLinalgOps(op)) {
      return failure();
    }
    if (!op.getResult().getType().template isa<TensorType>()) return failure();
    if (llvm::all_of(args, [](Value arg) {
          return arg.getType().template isa<TensorType>();
        })) {
      return failure();
    }
    Value result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
        op, getElementTypeOrSelf(op.getType()), args, &rewriter);
    rewriter.replaceOp(op, result);
    return success();
  }
};

SmallVector<Value, 8> GetReduceOpInitTensorDynSizes(
    OpBuilder& b, Location loc, Value arg, ShapedType result_type,
    ArrayRef<int64_t> reduction_dims) {
  llvm::SmallSetVector<int, 4> s;
  for (auto dim : reduction_dims) s.insert(dim);

  SmallVector<unsigned, 4> parallel_dims;
  SmallVector<Value, 8> dyn_shape;
  int rank = arg.getType().cast<RankedTensorType>().getRank();
  for (int i = 0, j = 0; i < rank; ++i) {
    if (s.count(i)) continue;
    if (!result_type.isDynamicDim(j++)) continue;
    dyn_shape.push_back(b.create<tensor::DimOp>(loc, arg, i));
  }

  return dyn_shape;
}

class ReduceRegionReturnOpConversion
    : public OpConversionPattern<mhlo::ReturnOp> {
 public:
  using OpConversionPattern<mhlo::ReturnOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::ReturnOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    if (!IsInBodyOfLinalgOps(op)) {
      return failure();
    }
    rewriter.replaceOpWithNewOp<linalg::YieldOp>(op, args);
    return success();
  }
};

class ReduceOnTensorsConversion : public OpConversionPattern<mhlo::ReduceOp> {
 public:
  using OpConversionPattern<mhlo::ReduceOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::ReduceOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    Location loc = op.getLoc();
    mhlo::ReduceOp::Adaptor adaptor(args);

    int num_inputs = static_cast<int>(adaptor.inputs().size());
    auto src_type = adaptor.inputs()[0].getType().cast<ShapedType>();
    int src_rank = src_type.getRank();
    if (!src_rank) {
      return rewriter.notifyMatchFailure(op, "expects known-rank args");
    }

    SmallVector<int64_t, 4> reduction_dims = Extract1DVector(op.dimensions());

    SmallVector<Value> inputs, outputs;
    SmallVector<AffineMap, 3> indexing_maps;
    for (int i = 0; i < num_inputs; ++i) {
      Value src = adaptor.inputs()[i];
      if (src.getType() != src_type) return failure();

      // Check if init_value is constant. If so, inline the value into the
      // region.
      Value init_value = adaptor.init_values()[i];
      Attribute init_const_val = GetInitValueAsConst(init_value);
      if (init_const_val) {
        init_value = rewriter.create<ConstantOp>(
            init_value.getDefiningOp()->getLoc(), init_const_val);
      } else {
        init_value = rewriter.create<tensor::ExtractOp>(loc, init_value);
      }

      inputs.push_back(src);
      auto result_type = op.getResult(i).getType().cast<ShapedType>();
      SmallVector<Value, 8> dyn_shape = GetReduceOpInitTensorDynSizes(
          rewriter, loc, src, result_type, reduction_dims);
      auto init_tensor = GetInitTensor(rewriter, loc, result_type, dyn_shape);
      Value filled_tensor =
          rewriter.create<linalg::FillOp>(loc, init_value, init_tensor)
              .result();
      outputs.push_back(filled_tensor);
    }

    // Prepare indexing maps for the linalg generic op. The elements are for
    // src and dst. Transpose `src` so that the reduction loops are innermost,
    // which makes it easier to fully utilize processors.
    indexing_maps.append(
        num_inputs, GetTransposeMapForReduction(rewriter.getContext(), src_rank,
                                                reduction_dims));

    // The indexing map of `dst` should drop the reduction loops. Since the
    // reduction loops are now all innermost, drop the last
    // `reduction_dims.size()` dimensions. We don't need an inverse
    // permutation here because they are the same.
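    // E.g., for an illustrative rank-3 input reduced over dimension {1}, the
    // src map is (d0, d1, d2) -> (d0, d2, d1) and the dst map is
    // (d0, d1, d2) -> (d0, d1), with d2 being the reduction loop.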
    SmallVector<AffineExpr, 4> exprs;
    for (int i = 0, e = src_rank - reduction_dims.size(); i < e; ++i)
      exprs.push_back(rewriter.getAffineDimExpr(i));
    indexing_maps.append(num_inputs,
                         AffineMap::get(src_rank, /*symbolCount=*/0, exprs,
                                        rewriter.getContext()));

    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/op.getResultTypes(), inputs,
        /*outputBuffers=*/ValueRange{outputs}, indexing_maps,
        GetParallelAndReductionIterators(src_rank, reduction_dims.size()));

    // Convert the signature of the body. The reduce op region apply function
    // has a signature (lhs, rhs) -> output, all of the same tensor type t.
    // This is converted to a function with the same shape but with element
    // types. E.g., "(tensor<f32>, tensor<f32>) -> tensor<f32>" will be
    // converted to "(f32, f32) -> f32".
    Region& region = linalg_op.region();
    rewriter.inlineRegionBefore(op.body(), region, region.end());
    TypeConverter::SignatureConversion signature_converter(num_inputs * 2);
    for (int i = 0; i < num_inputs * 2; ++i)
      signature_converter.addInputs(i, src_type.getElementType());
    rewriter.applySignatureConversion(&region, signature_converter);
    rewriter.replaceOp(op, linalg_op.getResults());
    return success();
  }
};

/// Converts mhlo.pad operation to linalg.pad_tensor op.
struct PadOpOnTensorsConversion : public OpConversionPattern<mhlo::PadOp> {
  using OpConversionPattern<mhlo::PadOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::PadOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const override {
    mhlo::PadOp::Adaptor adaptor(args);
    if (llvm::any_of(
            op.interior_padding().getValues<APInt>(),
            [](const APInt& int_val) { return int_val.getZExtValue() != 0; })) {
      return rewriter.notifyMatchFailure(op, "expected no interior padding");
    }

    auto loc = op.getLoc();
    Value padding_val =
        rewriter.createOrFold<tensor::ExtractOp>(loc, adaptor.padding_value());

    const auto& edge_padding_low = op.edge_padding_low();
    const auto& edge_padding_high = op.edge_padding_high();
    SmallVector<OpFoldResult, 4> low, high;
    for (auto it : llvm::zip(edge_padding_low, edge_padding_high)) {
      low.push_back(rewriter.createOrFold<ConstantIndexOp>(
          loc, std::get<0>(it).getZExtValue()));
      high.push_back(rewriter.createOrFold<ConstantIndexOp>(
          loc, std::get<1>(it).getZExtValue()));
    }
    Type result_type = op.getResult().getType();
    auto pad_tensor_op = linalg::PadTensorOp::createPadScalarOp(
        result_type, adaptor.operand(), padding_val, low, high, loc, rewriter);
    rewriter.replaceOp(op, pad_tensor_op.getResult());
    return success();
  }
};

/// Converts mhlo.conv operation to linalg named op. This only covers normal
/// convolution cases. The op must have canonical dimension numbers. Depthwise
/// convolution and pointwise convolution are not handled in the conversion.
struct NormalConvOpOnTensorsConversion
    : public OpConversionPattern<mhlo::ConvOp> {
  using OpConversionPattern<mhlo::ConvOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ConvOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const override {
    if (!HasCanonicalDimensionNumbers(op.dimension_numbers())) return failure();
    if (op.feature_group_count() != 1u) return failure();

    mhlo::ConvOp::Adaptor adaptor(args);
    Location loc = op.getLoc();
    Value input = adaptor.lhs();
    Value filter = adaptor.rhs();
    auto result_type = op.getResult().getType().cast<ShapedType>();
    int64_t rank = result_type.getRank();

    // Check if padding is zero.
    DenseIntElementsAttr padding = op.paddingAttr();
    if (padding && !isSplatValue(*op.padding(), 0)) {
      return rewriter.notifyMatchFailure(op, "expected no padding");
    }

    // The output shape is N spatial_dims F.
    SmallVector<Value, 8> dyn_sizes;
    if (result_type.isDynamicDim(0)) {
      dyn_sizes.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
    }
    for (int64_t i = 1, e = rank - 1; i < e; ++i) {
      if (result_type.isDynamicDim(i)) {
        return rewriter.notifyMatchFailure(
            op, "expected output spatial dims to be static shapes");
      }
    }
    if (result_type.isDynamicDim(rank - 1)) {
      dyn_sizes.push_back(
          rewriter.create<tensor::DimOp>(loc, filter, rank - 1));
    }
    Value init_tensor = rewriter.create<linalg::InitTensorOp>(
        loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
    auto zero_attr = rewriter.getZeroAttr(result_type.getElementType());
    Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
    Value zero_tensor =
        rewriter.create<linalg::FillOp>(loc, zero, init_tensor).getResult(0);
    linalg::LinalgOp res;
    Attribute strides = op.window_stridesAttr();
    // TODO(ataei): Only support dilated kernel right now. We need to consider
    // input dilation for deconvolution cases.
    Attribute dilations = op.rhs_dilationAttr();
    switch (rank) {
      case 3: {
        res = rewriter.create<linalg::ConvInputNWCFilterWCFOp>(
            loc, result_type, ValueRange{input, filter},
            ValueRange{zero_tensor}, dilations, strides);
        break;
      }
      case 4: {
        res = rewriter.create<linalg::ConvInputNHWCFilterHWCFOp>(
            loc, result_type, ValueRange{input, filter},
            ValueRange{zero_tensor}, dilations, strides);
        break;
      }
      case 5: {
        res = rewriter.create<linalg::ConvInputNDHWCFilterDHWCFOp>(
            loc, result_type, ValueRange{input, filter},
            ValueRange{zero_tensor}, dilations, strides);
        break;
      }
      default:
        return rewriter.notifyMatchFailure(op, "expected 1/2/3D conv op");
    }
    rewriter.replaceOp(op, res.getOperation()->getResults());
    return success();
  }
};

/// Converts mhlo.convolution operation to a
/// linalg.depthwise_conv_2d_input_nhwc_filter_hwcf op or a
/// linalg.depthwise_conv_2d_input_nhwc_filter_hwc op.
struct DepthwiseConvOpOnTensorsConversion
    : public OpConversionPattern<mhlo::ConvOp> {
  using OpConversionPattern<mhlo::ConvOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ConvOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const override {
    if (op.batch_group_count() != 1) return failure();

    if (op.padding() && !isSplatValue(*op.padding(), 0)) {
      return rewriter.notifyMatchFailure(op,
                                         "non-zero padding unsupported yet");
    }

    if (op.lhs_dilation() && !isSplatValue(*op.lhs_dilation(), 1)) {
      return rewriter.notifyMatchFailure(
          op, "non-one lhs dilation unsupported yet");
    }

    if (const mhlo::ConvDimensionNumbers& dimension_numbers =
            op.dimension_numbers()) {
      // Make sure that this is 2-D convolution.
      const auto spatial_rank =
          llvm::size(dimension_numbers.input_spatial_dimensions());
      if (spatial_rank != 2) {
        return rewriter.notifyMatchFailure(op,
                                           "only support 2-D cases for now");
      }

      // Make sure that this is depthwise convolution.
      int64_t input_feature_dim =
          dimension_numbers.input_feature_dimension().getInt();
      int64_t input_feature_count =
          op.lhs().getType().cast<ShapedType>().getDimSize(input_feature_dim);
      if (op.feature_group_count() != input_feature_count) {
        return rewriter.notifyMatchFailure(op, "not depth-wise convolution");
      }

      // Make sure that this convolution has a canonical form.
      if (!HasCanonicalDimensionNumbers(dimension_numbers)) {
        return rewriter.notifyMatchFailure(op, "does not have canonical form");
      }
    }

    DenseIntElementsAttr window_strides;
    if (op.window_strides()) {
      window_strides = op.window_strides().getValue();
    } else {
      window_strides = rewriter.getI64VectorAttr({1, 1});
    }

    DenseIntElementsAttr rhs_dilation;
    if (op.rhs_dilation()) {
      rhs_dilation = op.rhs_dilation().getValue();
    } else {
      rhs_dilation = rewriter.getI64VectorAttr({1, 1});
    }

    mhlo::ConvOp::Adaptor adaptor(args);
    Location loc = op.getLoc();
    Value input = adaptor.lhs();
    Value filter = adaptor.rhs();
    auto result_type = op.getResult().getType().cast<RankedTensorType>();
    if (!result_type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "expected output has static shapes");
    }

    auto filter_dims =
        llvm::to_vector<4>(op.rhs().getType().cast<ShapedType>().getShape());

    auto get_indices_vector = [](int start, int end) {
      return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
    };

    if (filter_dims[2] * filter_dims[3] != op.feature_group_count()) {
      // For cases where channel multiplier != 1
      auto output_dims = result_type.getShape();
      auto channel_multiplier = filter_dims[3];
      SmallVector<int64_t> reshaped_output_dims;
      reshaped_output_dims.assign(output_dims.begin(), output_dims.end());
      reshaped_output_dims.push_back(channel_multiplier);
      reshaped_output_dims[3] /= channel_multiplier;

      Value init_tensor = rewriter.create<linalg::InitTensorOp>(
          loc, reshaped_output_dims, result_type.getElementType());
      auto zero_attr = rewriter.getZeroAttr(result_type.getElementType());
      Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
      Value zero_tensor =
          rewriter.create<linalg::FillOp>(loc, zero, init_tensor).getResult(0);

      auto reshaped_output_type = RankedTensorType::get(
          reshaped_output_dims, result_type.getElementType());
      auto conv = rewriter.create<linalg::DepthwiseConvInputNHWCFilterHWCFOp>(
          op.getLoc(), reshaped_output_type, ValueRange{input, filter},
          ValueRange{zero_tensor}, rhs_dilation, window_strides);

      // Create a Linalg reshape op that converts the output from 5 dimensions
      // into 4 dimensions (by collapsing the last two dimensions). This is
      // needed because linalg.depthwise_conv_2d_input_nhwc_filter_hwcf returns
      // 5 dimensions for the output.
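      // E.g., for an illustrative NHWC result of 1x56x56x24 with channel
      // multiplier 3, the conv produces 1x56x56x8x3, and dims 3 and 4 are
      // collapsed back into 24.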
      SmallVector<ReassociationIndices, 4> collapsed_dim_list = {
          get_indices_vector(0, 1), get_indices_vector(1, 2),
          get_indices_vector(2, 3), get_indices_vector(3, 5)};
      rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
          op, result_type, conv.getResult(0), collapsed_dim_list);
    } else {
      // For cases where channel multiplier == 1
      Value init_tensor = rewriter.create<linalg::InitTensorOp>(
          loc, result_type.getShape(), result_type.getElementType());
      auto zero_attr = rewriter.getZeroAttr(result_type.getElementType());
      Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
      Value zero_tensor =
          rewriter.create<linalg::FillOp>(loc, zero, init_tensor).getResult(0);

      // Create a Linalg reshape op that converts the filter from 4 dimensions
      // into 3 dimensions (by dropping the unit dimension). This is needed
      // because linalg.depthwise_conv_2d_input_nhwc_filter_hwc expects 3
      // dimensions for the filter.

      filter_dims[2] = static_cast<int64_t>(op.feature_group_count());
      filter_dims.pop_back();

      RankedTensorType filter_shape =
          RankedTensorType::get(filter_dims, op.getType().getElementType());

      SmallVector<ReassociationIndices, 4> collapsed_dim_list = {
          get_indices_vector(0, 1), get_indices_vector(1, 2),
          get_indices_vector(2, 4)};

      Value reshaped_filter = rewriter.create<linalg::TensorCollapseShapeOp>(
          loc, filter_shape, filter, collapsed_dim_list);

      rewriter.replaceOpWithNewOp<linalg::DepthwiseConvInputNHWCFilterHWCOp>(
          op, result_type, ValueRange{input, reshaped_filter},
          ValueRange{zero_tensor}, rhs_dilation, window_strides);
    }

    return success();
  }
};

struct ReduceWindowOpOnTensorsConversion
    : public OpConversionPattern<mhlo::ReduceWindowOp> {
  using OpConversionPattern<mhlo::ReduceWindowOp>::OpConversionPattern;

  /// mhlo.reduce_window is mapped to a linalg.pooling operation. The type of
  /// the pooling is determined based on the body of the reduce window
  /// operation. This class enumerates the different variants.
  enum class PoolingType {
    kInvalid,
    k2DMin,
    k3DMin,
    k2DMax,
    k3DMax,
    k2DAdd,
    k3DAdd,
  };

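  // E.g., a reduce_window whose body is a single mhlo.max over a rank-4
  // (NHWC) operand is classified as k2DMax below.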
getPoolingTypemlir::__anonb93b5a210111::ReduceWindowOpOnTensorsConversion2299   static PoolingType getPoolingType(mhlo::ReduceWindowOp reduce_op,
2300                                     int result_index) {
2301     auto rank =
2302         reduce_op.getResultTypes()[result_index].cast<ShapedType>().getRank();
2303     if (Operation* op = reduce_op.getReductionOp(result_index)) {
2304       if (isa<mhlo::MinOp>(*op) && rank == 4) return PoolingType::k2DMin;
2305       if (isa<mhlo::MinOp>(*op) && rank == 5) return PoolingType::k3DMin;
2306       if (isa<mhlo::MaxOp>(*op) && rank == 4) return PoolingType::k2DMax;
2307       if (isa<mhlo::MaxOp>(*op) && rank == 5) return PoolingType::k3DMax;
2308       if (isa<mhlo::AddOp>(*op) && rank == 4) return PoolingType::k2DAdd;
2309       if (isa<mhlo::AddOp>(*op) && rank == 5) return PoolingType::k3DAdd;
2310     }
2311     return PoolingType::kInvalid;
2312   }
2313 
matchAndRewritemlir::__anonb93b5a210111::ReduceWindowOpOnTensorsConversion2314   LogicalResult matchAndRewrite(
2315       mhlo::ReduceWindowOp op, ArrayRef<Value> args,
2316       ConversionPatternRewriter& rewriter) const override {
2317     auto loc = op.getLoc();
2318     int rank = op.getResultTypes()[0].cast<ShapedType>().getRank();
2319     if (rank != 4 && rank != 5) {
2320       return rewriter.notifyMatchFailure(
2321           op, "expected NHWC/NDHWC pooling-based op");
2322     }
2323 
2324     if (op.padding() && !isSplatValue(*op.padding(), 0)) {
2325       return rewriter.notifyMatchFailure(op, "require paddings are all zero");
2326     }
2327 
2328     int last_dim = rank - 1;
2329     SmallVector<int64_t, 2> fake_window_shapes;
2330     for (int i = 1; i < last_dim; ++i) {
2331       fake_window_shapes.push_back(op.window_dimensions().getValue<int64_t>(i));
2332     }

    if (op.window_strides() &&
        (op.window_strides().getValue().getValue<int64_t>(0) != 1 ||
         op.window_strides().getValue().getValue<int64_t>(last_dim) != 1)) {
      return rewriter.notifyMatchFailure(
          op, "expected window_strides to be [1,x,y,(z),1]");
    }
    if (op.window_dimensions() &&
        (op.window_dimensions().getValue<int64_t>(0) != 1 ||
         op.window_dimensions().getValue<int64_t>(last_dim) != 1)) {
      return rewriter.notifyMatchFailure(
          op, "expected window_dimensions to be [1,x,y,(z),1]");
    }

    Attribute strides;
    SmallVector<int64_t> vec;
    if (op.window_stridesAttr()) {
      for (int i = 1; i < last_dim; ++i) {
        vec.push_back(op.window_strides().getValue().getValue<int64_t>(i));
      }
    } else {
      vec.assign(rank - 2, 1);
    }
    strides = rewriter.getI64VectorAttr(vec);

    Attribute dilations;
    vec.clear();
    if (op.window_dilations()) {
      for (int i = 1; i < last_dim; ++i) {
        vec.push_back(op.window_dilations().getValue().getValue<int64_t>(i));
      }
    } else {
      vec.assign(rank - 2, 1);
    }
    dilations = rewriter.getI64VectorAttr(vec);

    SmallVector<Value> pooling_ops;

    ArrayRef<Value> inputs = args.take_front(op.inputs().size());
    ArrayRef<Value> init_values = args.drop_front(op.inputs().size());
    for (auto it : llvm::zip(op.getResults(), inputs, init_values)) {
      OpResult result = std::get<0>(it);
      Value input = std::get<1>(it);
      Value init_value = std::get<2>(it);
      auto result_type = result.getType().cast<ShapedType>();
      if (!input.getType().cast<ShapedType>().getElementType().isF32()) {
        return rewriter.notifyMatchFailure(op,
                                           "expected element type to be f32");
      }

      // Create a fake window dimension.
      auto fake_window_dims = rewriter.create<linalg::InitTensorOp>(
          loc, fake_window_shapes, result_type.getElementType());

      SmallVector<Value> result_dynamic_dims;
      for (auto& en : llvm::enumerate(result_type.getShape())) {
        if (en.value() != ShapedType::kDynamicSize) continue;
        Value dim_size = rewriter.create<tensor::DimOp>(loc, input, en.index());
        if (en.index() == 0 || en.index() == rank - 1) {
          // Batch and channel dims can be derived directly from the input
          // dims.
          result_dynamic_dims.push_back(dim_size);
        } else {
          auto i = en.index() - 1;
          auto stride =
              strides.cast<DenseIntElementsAttr>().getValue<int64_t>(i);
          auto dilation =
              dilations.cast<DenseIntElementsAttr>().getValue<int64_t>(i);
          // Let j = i * stride; then
          // output[i] = reduce(input[j, j + window_size * dilation)).
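          // Worked example (hypothetical values): for an input dim of 10,
          // window_size = 3, stride = 2 and dilation = 1, the ops below
          // compute (10 - 3 * 1) / 2 + 1 = 4 output elements.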
          Value offset = rewriter.create<ConstantIndexOp>(
              loc, fake_window_shapes[i] * dilation);
          dim_size = rewriter.create<SubIOp>(loc, dim_size, offset);
          dim_size = rewriter.create<UnsignedDivIOp>(
              loc, dim_size, rewriter.create<ConstantIndexOp>(loc, stride));
          dim_size = rewriter.create<AddIOp>(
              loc, dim_size, rewriter.create<ConstantIndexOp>(loc, 1));
          result_dynamic_dims.push_back(dim_size);
        }
      }
      Value init_tensor = rewriter.create<linalg::InitTensorOp>(
          loc, result_dynamic_dims, result_type.getShape(),
          result_type.getElementType());

      init_value = rewriter.create<tensor::ExtractOp>(loc, init_value);
      Value filled_init_tensor =
          rewriter.create<linalg::FillOp>(loc, init_value, init_tensor)
              .getResult(0);
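      // Dispatch helper: the concrete linalg pooling op kind is passed as a
      // null typed pointer, so one lambda serves every case of the switch
      // below.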
      auto create_op = [&](auto* type_ptr) -> linalg::LinalgOp {
        return cast<linalg::LinalgOp>(
            rewriter
                .create<std::remove_pointer_t<decltype(type_ptr)>>(
                    loc, ArrayRef<Type>{result_type},
                    ValueRange{input, fake_window_dims.getResult()},
                    filled_init_tensor, strides, dilations)
                .getOperation());
      };
      linalg::LinalgOp pooling_op;
      PoolingType pooling_type = getPoolingType(op, result.getResultNumber());
      switch (pooling_type) {
        case PoolingType::k2DMin: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNhwcMinOp*>(nullptr));
          break;
        }
        case PoolingType::k3DMin: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNdhwcMinOp*>(nullptr));
          break;
        }
        case PoolingType::k2DMax: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNhwcMaxOp*>(nullptr));
          break;
        }
        case PoolingType::k3DMax: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNdhwcMaxOp*>(nullptr));
          break;
        }
        case PoolingType::k2DAdd: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNhwcSumOp*>(nullptr));
          break;
        }
        case PoolingType::k3DAdd: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNdhwcSumOp*>(nullptr));
          break;
        }
        case PoolingType::kInvalid:
          return rewriter.notifyMatchFailure(op, "unknown reduction operation");
      }
      pooling_ops.push_back(pooling_op->getResult(0));
    }
    rewriter.replaceOp(op, pooling_ops);
    return success();
  }
};

/// Converts mhlo.torch_index_select op to a linalg.generic op.
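/// For example (illustrative shapes), a row selection
///   %0 = "mhlo.torch_index_select"(%input, %index)
///          {dim = 0 : i64, batch_dims = 0 : i64}
///       : (tensor<5x8xf32>, tensor<2xi32>) -> tensor<2x8xf32>
/// becomes a linalg.generic whose body extracts input[index[i], j] for each
/// output position (i, j).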
struct TorchIndexSelectOpOnTensorsConversion
    : public OpConversionPattern<mhlo::TorchIndexSelectOp> {
  using OpConversionPattern<mhlo::TorchIndexSelectOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::TorchIndexSelectOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    mhlo::TorchIndexSelectOp::Adaptor adaptor(args);
    int axis = static_cast<int>(op.dim());
    int batch = static_cast<int>(op.batch_dims());
    auto index_shaped_type = adaptor.index().getType().cast<ShapedType>();
    int num_indices = static_cast<int>(index_shaped_type.getRank());
    auto input_shaped_type = adaptor.input().getType().cast<ShapedType>();
    if (axis < 0) axis += static_cast<int>(input_shaped_type.getRank());
    if (batch < 0) batch += num_indices;

    Location loc = op.getLoc();
    auto result_type =
        this->typeConverter->convertType(op.getResult().getType())
            .cast<ShapedType>();
    int rank = static_cast<int>(result_type.getRank());

    SmallVector<AffineMap, 2> indexing_maps;
    SmallVector<AffineExpr, 4> exprs;
    for (int i = 0; i < batch; ++i) {
      exprs.push_back(rewriter.getAffineDimExpr(i));
    }
    for (int i = 0, e = num_indices - batch; i < e; ++i) {
      exprs.push_back(rewriter.getAffineDimExpr(axis + i));
    }
    indexing_maps.emplace_back(
        AffineMap::get(rank, /*symbolCount=*/0, exprs, rewriter.getContext()));
    indexing_maps.emplace_back(rewriter.getMultiDimIdentityMap(rank));

    // The output shape is
    //   `params[:axis] + indices[batch_dims:] + params[axis + 1:]`
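    // Worked example (hypothetical shapes): for params of shape [4, 7, 8],
    // indices of shape [5, 6], axis = 1 and batch_dims = 0, the result shape
    // is [4] + [5, 6] + [8] = [4, 5, 6, 8].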
    SmallVector<Value, 4> dyn_sizes;
    for (int i = 0; i < rank; ++i) {
      if (!result_type.isDynamicDim(i)) continue;
      if (i < axis) {
        dyn_sizes.push_back(
            rewriter.create<tensor::DimOp>(loc, adaptor.input(), i));
      } else if (i < (axis + num_indices - batch)) {
        int idx = i - axis + batch;
        dyn_sizes.push_back(
            rewriter.create<tensor::DimOp>(loc, adaptor.index(), idx));
      } else {
        int idx = i - (axis + num_indices - batch) + axis + 1;
        dyn_sizes.push_back(
            rewriter.create<tensor::DimOp>(loc, adaptor.input(), idx));
      }
    }
    Value init_op = rewriter.create<linalg::InitTensorOp>(
        loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensors=*/ArrayRef<Type>{result_type},
        /*inputs=*/adaptor.index(),
        /*outputs=*/init_op, indexing_maps, GetNParallelLoopsAttrs(rank));

    SmallVector<Type, 4> body_arg_types;
    SmallVector<Value, 2> linalg_op_args = {adaptor.index()};
    // Add a block to the region.
    auto* region = &linalg_op.region();
    auto* block = rewriter.createBlock(region, region->end());
    for (auto block_args : linalg_op_args) {
      body_arg_types.push_back(
          block_args.getType().cast<ShapedType>().getElementType());
    }
    block->addArguments(body_arg_types);
    block->addArguments(result_type.getElementType());
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToEnd(block);

    Value casted_value = rewriter.create<IndexCastOp>(
        loc, block->getArgument(0), rewriter.getIndexType());

    SmallVector<Value, 4> indices;
    for (int i = 0; i < axis; ++i) {
      indices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
    }
    indices.push_back(casted_value);
    for (int i = axis + num_indices - batch; i < rank; ++i) {
      indices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
    }
    Value res =
        rewriter.create<tensor::ExtractOp>(loc, adaptor.input(), indices);
    rewriter.create<linalg::YieldOp>(loc, res);

    rewriter.replaceOp(op, linalg_op.getResults());
    return success();
  }
};

struct ScatterUpdateOnTensorsConversion
    : public OpConversionPattern<mhlo::ScatterOp> {
  using OpConversionPattern<mhlo::ScatterOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ScatterOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    mhlo::ScatterOp::Adaptor adaptor(args);

    // Check if it is a tensor_scatter_nd_update-like op.
    auto& body_ops = op.getRegion().front().getOperations();
    if (body_ops.size() != 1) return failure();
    auto ret_arg = body_ops.front().getOperand(0).dyn_cast<BlockArgument>();
    if (!ret_arg || ret_arg.getArgNumber() != 1) return failure();
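    // That is, the update region must simply yield the update value, e.g.
    // (illustrative):
    //   ^bb0(%old: tensor<f32>, %update: tensor<f32>):
    //     "mhlo.return"(%update) : (tensor<f32>) -> ()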

    auto operand_ty = adaptor.operand().getType().dyn_cast<RankedTensorType>();
    auto indices_ty =
        adaptor.scatter_indices().getType().dyn_cast<RankedTensorType>();
    if (!operand_ty || !indices_ty) return failure();

    // Linalg operations put all the computation in the innermost loop. Since
    // we also iterate over scatter_indices() with some loops, we can only
    // check one scatter index per iteration. If there are multiple indices
    // (i.e., the index depth is greater than 1), we don't have a way to keep
    // the comparison state. E.g., for index_depth 2 with indices = [[0, 1]],
    // we should use the update value only if (i == 0 and j == 1). However, we
    // cannot get both indices in one iteration unless we pack them together.
    auto index_vector_dim =
        op.scatter_dimension_numbers().index_vector_dim().getInt();
    if (indices_ty.getDimSize(index_vector_dim) != 1)
      return rewriter.notifyMatchFailure(op, "require index depth to be 1");
    if (index_vector_dim != indices_ty.getRank() - 1) {
      return rewriter.notifyMatchFailure(
          op, "require index_vector_dim to be the last dim");
    }

    // One of the indices dims is the index depth vector.
    int64_t nloops = operand_ty.getRank() + indices_ty.getRank() - 1;
    SmallVector<AffineMap, 3> indexing_maps;
    {
      SmallVector<AffineExpr> exprs;
      for (int64_t i = 0, e = operand_ty.getRank(); i < e; ++i)
        exprs.push_back(rewriter.getAffineDimExpr(i));
      indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
                                             rewriter.getContext()));
    }
    {
      SmallVector<AffineExpr> exprs;
      for (int64_t i = operand_ty.getRank(); i < nloops; ++i)
        exprs.push_back(rewriter.getAffineDimExpr(i));
      // The index depth is 1.
      exprs.push_back(rewriter.getAffineConstantExpr(0));
      indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
                                             rewriter.getContext()));

      exprs.pop_back();
      auto update_window_dims =
          Extract1DVector(op.scatter_dimension_numbers().update_window_dims());
      for (auto d : update_window_dims)
        exprs.push_back(rewriter.getAffineDimExpr(d));
      indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
                                             rewriter.getContext()));
    }
    indexing_maps.push_back(indexing_maps.front());
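    // Illustrative example (hypothetical op): scattering updates of shape
    // [4] into an operand of shape [8], with indices of shape [4, 1],
    // update_window_dims = [] and index_vector_dim = 1, gives nloops = 2 and,
    // roughly, these maps:
    //   operand/output: (d0, d1) -> (d0)
    //   indices:        (d0, d1) -> (d1, 0)
    //   updates:        (d0, d1) -> (d1)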

    auto result_ty = this->typeConverter->convertType(op.getResult().getType())
                         .cast<ShapedType>();
    auto scatter_dims_to_operand_dims = Extract1DVector(
        op.scatter_dimension_numbers().scatter_dims_to_operand_dims());
    assert(scatter_dims_to_operand_dims.size() == 1);
    // No init_tensor is needed because the output is initialized from the
    // operand.
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        op.getLoc(), /*resultTensors=*/ArrayRef<Type>{result_ty},
        /*inputs=*/
        ValueRange{adaptor.operand(), adaptor.scatter_indices(),
                   adaptor.updates()},
        /*outputs=*/adaptor.operand(), indexing_maps,
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& b, Location loc, ValueRange args) {
          Value cmp_idx =
              b.create<linalg::IndexOp>(loc, scatter_dims_to_operand_dims[0]);
          Value idx = b.create<IndexCastOp>(loc, b.getIndexType(), args[1]);
          Value pred = b.create<CmpIOp>(loc, b.getI1Type(), CmpIPredicate::eq,
                                        cmp_idx, idx);
          // Read from the output argument so that values already written by
          // earlier iterations are not reset to the operand's value.
          Value res = b.create<SelectOp>(loc, args[2].getType(), pred, args[2],
                                         args[3]);
          b.create<linalg::YieldOp>(loc, res);
        });
    rewriter.replaceOp(op, linalg_op.getResults());
    return success();
  }
};

void populateLHLOToLinalgConversionPattern(MLIRContext* context,
                                           TypeConverter& typeConverter,
                                           OwningRewritePatternList* patterns) {
  // clang-format off
  patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
                   ConstConverterBuffer,
                   ConvToLinalgConverter,
                   IotaConverter<lmhlo::IotaOp>,
                   LhloBroadcastInDimConverter,
                   PointwiseToLinalgConverter<lmhlo::AbsOp>,
                   PointwiseToLinalgConverter<lmhlo::AddOp>,
                   PointwiseToLinalgConverter<lmhlo::AndOp>,
                   PointwiseToLinalgConverter<lmhlo::Atan2Op>,
                   PointwiseToLinalgConverter<lmhlo::CeilOp>,
                   PointwiseToLinalgConverter<lmhlo::ClampOp>,
                   PointwiseToLinalgConverter<lmhlo::CompareOp>,
                   PointwiseToLinalgConverter<lmhlo::ComplexOp>,
                   PointwiseToLinalgConverter<lmhlo::ConvertOp>,
                   // TODO(ataei): Remove this pattern, CopyOp is folded away.
                   PointwiseToLinalgConverter<lmhlo::CopyOp>,
                   PointwiseToLinalgConverter<lmhlo::CosOp>,
                   PointwiseToLinalgConverter<lmhlo::DivOp>,
                   PointwiseToLinalgConverter<lmhlo::ExpOp>,
                   PointwiseToLinalgConverter<lmhlo::Expm1Op>,
                   PointwiseToLinalgConverter<lmhlo::FloorOp>,
                   PointwiseToLinalgConverter<lmhlo::ImagOp>,
                   PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
                   PointwiseToLinalgConverter<lmhlo::LogOp>,
                   PointwiseToLinalgConverter<lmhlo::LogisticOp>,
                   PointwiseToLinalgConverter<lmhlo::Log1pOp>,
                   PointwiseToLinalgConverter<lmhlo::MaxOp>,
                   PointwiseToLinalgConverter<lmhlo::MinOp>,
                   PointwiseToLinalgConverter<lmhlo::MulOp>,
                   PointwiseToLinalgConverter<lmhlo::NegOp>,
                   PointwiseToLinalgConverter<lmhlo::NotOp>,
                   PointwiseToLinalgConverter<lmhlo::OrOp>,
                   PointwiseToLinalgConverter<lmhlo::PowOp>,
                   PointwiseToLinalgConverter<lmhlo::RealOp>,
                   PointwiseToLinalgConverter<lmhlo::RemOp>,
                   PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
                   PointwiseToLinalgConverter<lmhlo::SelectOp>,
                   PointwiseToLinalgConverter<lmhlo::ShiftLeftOp>,
                   PointwiseToLinalgConverter<lmhlo::ShiftRightArithmeticOp>,
                   PointwiseToLinalgConverter<lmhlo::ShiftRightLogicalOp>,
                   PointwiseToLinalgConverter<lmhlo::SignOp>,
                   PointwiseToLinalgConverter<lmhlo::SinOp>,
                   PointwiseToLinalgConverter<lmhlo::SqrtOp>,
                   PointwiseToLinalgConverter<lmhlo::SubOp>,
                   PointwiseToLinalgConverter<lmhlo::TanhOp>,
                   PointwiseToLinalgConverter<lmhlo::XorOp>,
                   ReduceConverter,
                   ReshapeOpConverter<lmhlo::ReshapeOp>,
                   ReverseConverter<lmhlo::ReverseOp>,
                   ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
                   ScalarPointwiseToStandardConverter<lmhlo::MaxOp>,
                   SliceConverter<lmhlo::SliceOp>,
                   TransposeConverter<lmhlo::TransposeOp>
                  >(typeConverter, context);
  // clang-format on
}

struct ComputeReshapeShapeConversion
    : public OpConversionPattern<mhlo::ComputeReshapeShapeOp> {
  using OpConversionPattern<mhlo::ComputeReshapeShapeOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::ComputeReshapeShapeOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = op.getLoc();
    auto ctx = op->getContext();
    Value neg_one = rewriter.create<ConstantIndexOp>(loc, -1);
    auto index_type = rewriter.getIndexType();
    auto num_elements = args[0];
    auto target_shape_type = args[1].getType().cast<ShapedType>();
    auto extent_type =
        shape::getExtentTensorType(ctx, target_shape_type.getDimSize(0));

    // Compute the actual extent to use for a possibly dynamic (-1) extent.
    auto new_shape =
        target_shape_type.getElementType().isIndex()
            ? args[1]
            : rewriter.create<IndexCastOp>(loc, extent_type, args[1]);
    Value new_shape_rank =
        rewriter.create<shape::RankOp>(loc, index_type, new_shape);
    // The product begins with a -1 seed which will cancel out a -1 extent in
    // the input shape if there is one. If there is none, this computed result
    // will never be used, so it's okay to compute a negative number of
    // elements.
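    // Worked example (hypothetical values): for num_elements = 24 and a
    // target shape of [2, -1, 3], the seeded product is -1 * 2 * -1 * 3 = 6,
    // so the missing extent computed below is 24 / 6 = 4.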
    auto accounted_num_els =
        rewriter.create<shape::ReduceOp>(loc, new_shape, neg_one);
    {
      PatternRewriter::InsertionGuard g(rewriter);
      rewriter.setInsertionPointToEnd(accounted_num_els.getBody());
      Value lhs = accounted_num_els.getBody()->getArgument(1);
      Value rhs = accounted_num_els.getBody()->getArgument(2);
      rewriter.create<shape::YieldOp>(
          loc, rewriter.create<MulIOp>(loc, lhs, rhs).result());
    }
    Value missing_dim_val = rewriter.create<UnsignedDivIOp>(
        loc, num_elements, accounted_num_els->getResult(0));

    // Create the final target shape, with any dynamic (-1) extent replaced by
    // the computed extent.
    SmallVector<Value> dynamic_extent;
    if (!target_shape_type.hasStaticShape())
      dynamic_extent.push_back(new_shape_rank);
    auto gen = rewriter.create<tensor::GenerateOp>(
        loc, target_shape_type, dynamic_extent,
        [&](OpBuilder& b, Location loc, ValueRange indices) {
          Value extent = b.create<shape::GetExtentOp>(loc, index_type,
                                                      new_shape, indices[0]);
          Value use_missing_dim_val =
              b.create<CmpIOp>(loc, CmpIPredicate::eq, extent, neg_one);
          Value dim_val = b.create<SelectOp>(loc, use_missing_dim_val,
                                             missing_dim_val, extent);
          dim_val = target_shape_type.getElementType().isIndex()
                        ? dim_val
                        : b.create<IndexCastOp>(
                              loc, target_shape_type.getElementType(), dim_val);
          b.create<tensor::YieldOp>(loc, dim_val);
        });
    rewriter.replaceOp(op, gen.result());

    return success();
  }
};

struct CstrReshapableConversion
    : public OpConversionPattern<mhlo::CstrReshapableOp> {
  using OpConversionPattern<mhlo::CstrReshapableOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::CstrReshapableOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = op.getLoc();
    auto ctx = op->getContext();
    Value neg_one = rewriter.create<ConstantIndexOp>(loc, -1);
    Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<ConstantIndexOp>(loc, 1);
    Value two = rewriter.create<ConstantIndexOp>(loc, 2);
    auto num_elements = args[0];
    auto target_shape_type = args[1].getType().cast<ShapedType>();
    auto extent_type =
        shape::getExtentTensorType(ctx, target_shape_type.getDimSize(0));

    // Cast the target shape to an extent tensor of index type if necessary.
    auto new_shape =
        target_shape_type.getElementType().isIndex()
            ? args[1]
            : rewriter.create<IndexCastOp>(loc, extent_type, args[1]);
    auto reduction = rewriter.create<shape::ReduceOp>(
        loc, new_shape, llvm::makeArrayRef({one, zero, zero}));
    {
      PatternRewriter::InsertionGuard g(rewriter);
      auto body = reduction.getBody();
      rewriter.setInsertionPointToEnd(body);
      Value extent = body->getArgument(1);
      Value is_dynamic =
          rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, neg_one, extent);
      Value is_invalid =
          rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, extent, neg_one);
      Value total_dynamic = rewriter.create<AddIOp>(
          loc, rewriter.create<SelectOp>(loc, is_dynamic, one, zero),
          body->getArgument(3));
      Value total_invalid = rewriter.create<AddIOp>(
          loc, rewriter.create<SelectOp>(loc, is_invalid, one, zero),
          body->getArgument(4));
      Value extent_or_one =
          rewriter.create<SelectOp>(loc, is_dynamic, one, extent);
      Value total_elements =
          rewriter.create<MulIOp>(loc, extent_or_one, body->getArgument(2));
      rewriter.create<shape::YieldOp>(
          loc,
          llvm::makeArrayRef({total_elements, total_dynamic, total_invalid}));
    }
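    // Worked example (hypothetical values): for num_elements = 24 and a
    // target shape of [2, -1, 3], the reduction yields total_elements = 6
    // (a dynamic extent counts as 1), total_dynamic = 1 and total_invalid =
    // 0, so every check below passes: 24 % 6 == 0, fewer than two dynamic
    // extents, and no invalid extents.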
    Value is_divisible = rewriter.create<CmpIOp>(
        loc, CmpIPredicate::eq, zero,
        rewriter.create<SignedRemIOp>(loc, num_elements,
                                      reduction->getResult(0)));
    // At most one extent may be dynamic (-1), i.e. total_dynamic < 2.
    Value acceptably_dynamic = rewriter.create<CmpIOp>(
        loc, CmpIPredicate::ult, reduction->getResult(1), two);
    // No extent may be invalid (< -1), i.e. total_invalid == 0.
    Value no_invalid = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, zero,
                                               reduction->getResult(2));

    Value all_passing = rewriter.create<AndOp>(
        loc, is_divisible,
        rewriter.create<AndOp>(loc, acceptably_dynamic, no_invalid));

    rewriter.replaceOpWithNewOp<shape::CstrRequireOp>(
        op, all_passing, "Required valid reshape shape input");

    return success();
  }
};

// Converts LHLO ops to Linalg generic.
// Sample result for lmhlo::AddOp.
//
// "lmhlo.add"(%arg1, %arg2, %out) :
//      (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
//
// will be converted to
//
// #map0 = (d0, d1) -> (d0, d1)
// "linalg.generic"(%arg1, %arg2, %out) ( {
//   ^bb0(%arg4: f32, %arg5: f32):
//     %0 = addf %arg4, %arg5 : f32
//     "linalg.yield"(%0) : (f32) -> ()
// }) {
//     indexing_maps = [#map0, #map0, #map0],
//     iterator_types = ["parallel", "parallel"],
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
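//
// The pass is typically exercised via mlir-opt; assuming the flag name that
// the pass base class registers, something like:
//   mlir-opt --lhlo-legalize-to-linalg input.mlir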
struct LhloLegalizeToLinalgPass
    : public lmhlo::LhloLegalizeToLinalgPassBase<LhloLegalizeToLinalgPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry
        .insert<AffineDialect, complex::ComplexDialect, linalg::LinalgDialect,
                math::MathDialect, memref::MemRefDialect>();
  }

  void runOnFunction() override {
    OwningRewritePatternList patterns(&getContext());
    ConversionTarget target(getContext());
    target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
                           math::MathDialect, memref::MemRefDialect,
                           StandardOpsDialect, AffineDialect>();
    target.addLegalOp<UnrealizedConversionCastOp>();

    mhlo::RemoveSignTypeConverter type_converter;
    auto func = getFunction();
    populateLHLOToLinalgConversionPattern(func.getContext(), type_converter,
                                          &patterns);
    if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

struct HloLegalizeToLinalgPass
    : public mhlo::HloLegalizeToLinalgPassBase<HloLegalizeToLinalgPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<linalg::LinalgDialect, scf::SCFDialect,
                    complex::ComplexDialect, math::MathDialect,
                    memref::MemRefDialect, shape::ShapeDialect>();
  }

  void runOnFunction() override {
    MLIRContext& ctx = getContext();
    OwningRewritePatternList patterns(&ctx);
    ConversionTarget target(ctx);
    target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
                           math::MathDialect, StandardOpsDialect,
                           tensor::TensorDialect, scf::SCFDialect,
                           shape::ShapeDialect>();

    target.addLegalOp<UnrealizedConversionCastOp>();

    mhlo::RemoveSignTypeConverter type_converter;
    auto func = getFunction();
    mhlo::populateHLOToLinalgConversionPattern(&ctx, type_converter, &patterns);
    if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

}  // namespace

namespace lmhlo {
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass() {
  return std::make_unique<LhloLegalizeToLinalgPass>();
}
}  // namespace lmhlo

namespace mhlo {

void populateHLOToLinalgConversionPattern(MLIRContext* context,
                                          TypeConverter& type_converter,
                                          OwningRewritePatternList* patterns) {
  // clang-format off
  patterns->insert<
      BroadcastConverter<mhlo::BroadcastOp, false>, ConcatenateConverter,
      ConstConverterTensor, HloDynamicBroadcastInDimConverter,
      HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
      EinsumToLinalgConverter,
      IotaConverter<mhlo::DynamicIotaOp, false>,
      PointwiseToLinalgConverter<mhlo::AbsOp, false>,
      PointwiseToLinalgConverter<mhlo::AddOp, false>,
      PointwiseToLinalgConverter<mhlo::AndOp, false>,
      PointwiseToLinalgConverter<mhlo::Atan2Op, false>,
      PointwiseToLinalgConverter<mhlo::CeilOp, false>,
      PointwiseToLinalgConverter<mhlo::ClampOp, false>,
      PointwiseToLinalgConverter<mhlo::CompareOp, false>,
      PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
      PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
      PointwiseToLinalgConverter<mhlo::CopyOp, false>,
      PointwiseToLinalgConverter<mhlo::CosOp, false>,
      PointwiseToLinalgConverter<mhlo::DivOp, false>,
      PointwiseToLinalgConverter<mhlo::ExpOp, false>,
      PointwiseToLinalgConverter<mhlo::Expm1Op, false>,
      PointwiseToLinalgConverter<mhlo::FloorOp, false>,
      PointwiseToLinalgConverter<mhlo::ImagOp, false>,
      PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
      PointwiseToLinalgConverter<mhlo::LogOp, false>,
      PointwiseToLinalgConverter<mhlo::LogisticOp, false>,
      PointwiseToLinalgConverter<mhlo::Log1pOp, false>,
      PointwiseToLinalgConverter<mhlo::MaxOp, false>,
      PointwiseToLinalgConverter<mhlo::MinOp, false>,
      PointwiseToLinalgConverter<mhlo::MulOp, false>,
      PointwiseToLinalgConverter<mhlo::NegOp, false>,
      PointwiseToLinalgConverter<mhlo::NotOp, false>,
      PointwiseToLinalgConverter<mhlo::OrOp, false>,
      PointwiseToLinalgConverter<mhlo::PowOp, false>,
      PointwiseToLinalgConverter<mhlo::RealOp, false>,
      PointwiseToLinalgConverter<mhlo::RemOp, false>,
      PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
      PointwiseToLinalgConverter<mhlo::SelectOp, false>,
      PointwiseToLinalgConverter<mhlo::ShiftLeftOp, false>,
      PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp, false>,
      PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp, false>,
      PointwiseToLinalgConverter<mhlo::SignOp, false>,
      PointwiseToLinalgConverter<mhlo::SinOp, false>,
      PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
      PointwiseToLinalgConverter<mhlo::SubOp, false>,
      PointwiseToLinalgConverter<mhlo::TanhOp, false>,
      PointwiseToLinalgConverter<mhlo::XorOp, false>,
      ReshapeOpConverter<mhlo::ReshapeOp, false>,
      ReverseConverter<mhlo::ReverseOp, false>,
      SliceConverter<mhlo::SliceOp, false>,
      ComputeReshapeShapeConversion,
      CstrReshapableConversion,
      DynamicSliceConverter,
      DynamicUpdateSliceConverter,
      TransposeConverter<mhlo::TransposeOp, false>,
      DotOpOnTensorsConversion<DotOperationType::kMatrixMatrix,
                               linalg::MatmulOp>,
      DotOpOnTensorsConversion<DotOperationType::kMatrixVector,
                               linalg::MatvecOp>,
      DotOpOnTensorsConversion<DotOperationType::kVectorDot, linalg::DotOp>,
      DotGeneralOpOnTensorsConversion,
      NormalConvOpOnTensorsConversion,
      DepthwiseConvOpOnTensorsConversion,
      ReduceOnTensorsConversion,
      ReduceWindowOpOnTensorsConversion,
      ScatterUpdateOnTensorsConversion,
      TorchIndexSelectOpOnTensorsConversion,
      PadOpOnTensorsConversion>(type_converter, context);
  // clang-format on
  patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,
                   ReduceRegionXLAOpConversion<mhlo::AndOp>,
                   ReduceRegionXLAOpConversion<mhlo::CompareOp>,
                   ReduceRegionXLAOpConversion<mhlo::MaxOp>,
                   ReduceRegionXLAOpConversion<mhlo::MinOp>,
                   ReduceRegionXLAOpConversion<mhlo::MulOp>,
                   ReduceRegionXLAOpConversion<mhlo::OrOp>,
                   ReduceRegionXLAOpConversion<mhlo::SelectOp>,
                   ReduceRegionReturnOpConversion>(context,
                                                   PatternBenefit(1000));
}

std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
  return std::make_unique<HloLegalizeToLinalgPass>();
}

std::unique_ptr<TypeConverter> createHloToLinalgSignedIntegerConverter() {
  return std::make_unique<RemoveSignTypeConverter>();
}

}  // namespace mhlo
}  // namespace mlir