/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.

#include <numeric>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace {

/// Returns an ArrayAttr that contains `nLoops` attributes. All the attributes
/// are "parallel" except the last `nReduction` elements, which are "reduction"
/// attributes.
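/// E.g. (illustrative), GetParallelAndReductionIterators(4, 1) returns
/// {"parallel", "parallel", "parallel", "reduction"}.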
SmallVector<StringRef, 3> GetParallelAndReductionIterators(
    unsigned nLoops, unsigned nReduction) {
  SmallVector<StringRef, 3> res(nLoops - nReduction,
                                getParallelIteratorTypeName());
  res.append(nReduction, getReductionIteratorTypeName());
  return res;
}

SmallVector<StringRef, 3> GetNParallelLoopsAttrs(unsigned nParallelLoops) {
  return GetParallelAndReductionIterators(nParallelLoops, 0);
}

template <bool isLHLO = true>
Value GetResultValue(Operation* op) {
  return isLHLO ? op->getOperand(op->getNumOperands() - 1) : op->getResult(0);
}

template <bool isLHLO = true>
ShapedType GetHloOpResultType(Operation* op) {
  return GetResultValue<isLHLO>(op).getType().template cast<ShapedType>();
}

template <bool isLHLO = true>
bool VerifyHloOpBufferOrTensorSemantics(Operation* op) {
  auto verify_type = [&](Value val) -> bool {
    return (isLHLO && val.getType().isa<MemRefType>()) ||
           (!isLHLO && val.getType().isa<RankedTensorType>());
  };
  if (!llvm::all_of(op->getOperands(), verify_type)) return false;
  return isLHLO ? op->getResults().empty()
                : llvm::all_of(op->getResults(), verify_type);
}

Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type,
                    ArrayRef<Value> dyn_sizes) {
  return b.create<linalg::InitTensorOp>(loc, dyn_sizes, type.getShape(),
                                        type.getElementType());
}

SmallVector<Value, 2> ExtractDynamicSizes(OpBuilder& b, Location loc,
                                          Value tensor,
                                          Value shape_tensor = nullptr,
                                          AffineMap permutation = {}) {
  auto tensor_type = tensor.getType().dyn_cast<RankedTensorType>();
  if (!tensor_type) return {};
  SmallVector<Value, 2> dyn_sizes(tensor_type.getRank());
  for (auto& en : llvm::enumerate(tensor_type.getShape())) {
    if (en.value() != ShapedType::kDynamicSize) continue;
    // If a shape tensor is present, extract the dynamic size from there.
    if (shape_tensor) {
      Value extract = b.create<tensor::ExtractOp>(
          loc, shape_tensor,
          ValueRange{b.create<ConstantIndexOp>(loc, en.index())});
      dyn_sizes[en.index()] =
          b.create<IndexCastOp>(loc, b.getIndexType(), extract);
    } else {
      dyn_sizes[en.index()] = b.create<tensor::DimOp>(loc, tensor, en.index());
    }
  }
  if (permutation)
    dyn_sizes = applyPermutationMap(permutation, makeArrayRef(dyn_sizes));
  llvm::erase_value(dyn_sizes, nullptr);  // Strip out placeholders.
  return dyn_sizes;
}

SmallVector<int64_t, 4> Extract1DVector(DenseIntElementsAttr elements) {
  SmallVector<int64_t, 4> ret;
  for (const APInt& element : elements) {
    ret.push_back(element.getLimitedValue());
  }
  return ret;
}

/// Returns the constant value associated with the init value if the defining
/// operation is a constant.
Attribute GetInitValueAsConst(Value init) {
  DenseElementsAttr attr;
  if (!matchPattern(init, m_Constant(&attr))) return {};
  auto type = attr.getType().dyn_cast<ShapedType>();
  if (!type || type.getRank() != 0) return {};
  return attr.getValue({});
}

/// Returns a permutation AffineMap that puts all reduction dimensions last.
/// The orders of the parallel loops and of the reduction loops are each
/// preserved. E.g., if `rank` is 4 and `reductionDims` is {1, 3}, then
/// "(d0, d1, d2, d3) -> (d0, d2, d1, d3)" is used. The inverse permutation of
/// the AffineMap is returned.
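/// (In this example the permutation swaps d1 and d2, so the map happens to be
/// its own inverse.)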
AffineMap GetTransposeMapForReduction(MLIRContext* context, int rank,
                                      ArrayRef<int64_t> reduction_dims) {
  llvm::SmallSetVector<int, 4> s;
  for (auto dim : reduction_dims) s.insert(dim);

  SmallVector<unsigned, 4> permutation;
  for (int i = 0; i < rank; ++i)
    if (!s.count(i)) permutation.push_back(i);
  for (auto dim : reduction_dims) permutation.push_back(dim);

  auto map = AffineMap::getPermutationMap(permutation, context);
  return inversePermutation(map);
}

/// Returns true if the given `attr` is a splat of the given `value`.
bool isSplatValue(DenseIntElementsAttr attr, uint64_t value) {
  return attr.isSplat() && attr.getSplatValue<uint64_t>() == value;
}

/// Returns true if the given `dimensionNumbers` from an mhlo.convolution op
/// follows a canonical form:
///
/// * Input dimensions have order: (batch_count, spatial_dims,
///   input_channel_count).
/// * Filter dimensions have order: (spatial_dims, input_channel_count,
///   output_channel_count).
/// * Output dimensions have order: (batch_count, spatial_dims,
///   output_channel_count).
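/// E.g. (illustrative), a canonical 2-D NHWC convolution has dimension numbers
/// [b, 0, 1, f] x [0, 1, i, o] -> [b, 0, 1, f]: input batch dim 0 and feature
/// dim 3, kernel input/output feature dims 2 and 3, spatial dims in between.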
template <typename DimensionNumbersTy>
static bool HasCanonicalDimensionNumbers(
    const DimensionNumbersTy& dimension_numbers) {
  const int input_spatial_rank =
      llvm::size(dimension_numbers.input_spatial_dimensions());
  // The dimensions for input should follow the order of
  // batch_count, spatial_dims..., input_feature_count.
  if (dimension_numbers.input_batch_dimension().getInt() != 0 ||
      dimension_numbers.input_feature_dimension().getInt() !=
          (input_spatial_rank + 1)) {
    return false;
  }

  const int kernel_spatial_rank =
      llvm::size(dimension_numbers.kernel_spatial_dimensions());
  // The dimensions for filter should follow the order of
  // spatial_dims..., input_feature_count, output_feature_count.
  if (dimension_numbers.kernel_input_feature_dimension().getInt() !=
          kernel_spatial_rank ||
      dimension_numbers.kernel_output_feature_dimension().getInt() !=
          (kernel_spatial_rank + 1)) {
    return false;
  }

  const int output_spatial_rank =
      llvm::size(dimension_numbers.output_spatial_dimensions());
  // The dimensions for output should follow the order of
  // batch_count, spatial_dims..., output_feature_count.
  if (dimension_numbers.output_batch_dimension().getInt() != 0 ||
      dimension_numbers.output_feature_dimension().getInt() !=
          (output_spatial_rank + 1)) {
    return false;
  }

  if (input_spatial_rank != output_spatial_rank ||
      input_spatial_rank != kernel_spatial_rank) {
    return false;
  }

  auto input_spatial_dim = dimension_numbers.input_spatial_dimensions().begin();
  auto kernel_spatial_dim =
      dimension_numbers.kernel_spatial_dimensions().begin();
  auto output_spatial_dim =
      dimension_numbers.output_spatial_dimensions().begin();
  // Check that the spatial dims are ordered correctly.
  for (int i = 0; i < input_spatial_rank; ++i) {
    const int dim = i + 1;
    if ((*input_spatial_dim++).getZExtValue() != dim ||
        (*output_spatial_dim++).getZExtValue() != dim ||
        (*kernel_spatial_dim++).getZExtValue() != i) {
      return false;
    }
  }

  return true;
}

//===----------------------------------------------------------------------===//
// mhlo.Einsum conversion patterns.
//===----------------------------------------------------------------------===//

// Looks up each input dimension in the set of dimensions that have been marked
// as reduction axes: if it is found within that set we label it "reduction",
// otherwise we label it "parallel".
SmallVector<StringRef, 3> GetEinsumLoopsAttrs(
    const llvm::SmallSetVector<StringRef, 4>& input_ind,
    const llvm::SmallSetVector<StringRef, 4>& reduction_dims) {
  SmallVector<StringRef, 3> res;
  for (StringRef dim : input_ind) {
    if (!reduction_dims.contains(dim)) {
      res.push_back(getParallelIteratorTypeName());
    } else {
      res.push_back(getReductionIteratorTypeName());
    }
  }
  return res;
}

SmallVector<Value, 2> ExtractDynamicEinsumSizes(
    OpBuilder& b, Location loc, Value lhs, Value rhs,
    const SmallVector<std::string>& lhs_loop_vec,
    const SmallVector<std::string>& rhs_loop_vec,
    const SmallVector<std::string>& output_loop_vec) {
  SmallVector<Value, 2> dyn_sizes;
  for (const std::string& dim_ind : output_loop_vec) {
    Value dim_size;
    auto dim_ind_it =
        std::find(lhs_loop_vec.begin(), lhs_loop_vec.end(), dim_ind);
    if (dim_ind_it != lhs_loop_vec.end()) {
      // Query from lhs vars.
      auto dim_ind_pos = dim_ind_it - lhs_loop_vec.begin();
      auto lhs_shape = lhs.getType().dyn_cast<RankedTensorType>().getShape();
      if (lhs_shape[dim_ind_pos] != ShapedType::kDynamicSize) continue;
      dim_size = b.create<tensor::DimOp>(loc, lhs, dim_ind_pos);
    } else {
      // Query from rhs vars.
      dim_ind_it = std::find(rhs_loop_vec.begin(), rhs_loop_vec.end(), dim_ind);
      auto dim_ind_pos = dim_ind_it - rhs_loop_vec.begin();
      auto rhs_shape = rhs.getType().dyn_cast<RankedTensorType>().getShape();
      if (rhs_shape[dim_ind_pos] != ShapedType::kDynamicSize) continue;
      dim_size = b.create<tensor::DimOp>(loc, rhs, dim_ind_pos);
    }
    dyn_sizes.push_back(dim_size);
  }
  return dyn_sizes;
}

// Adds the indices/axes that are missing from the output set.
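// E.g., for einsum config "abc,cd->abd" the input set is {a, b, c, d} and the
// output set is {a, b, d}, so the summation axis is {c}.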
llvm::SmallSetVector<StringRef, 4> FindSummationAxes(
    const llvm::SmallSetVector<StringRef, 4>& input_set,
    const llvm::SmallSetVector<StringRef, 4>& output_set) {
  llvm::SmallSetVector<StringRef, 4> summation_axes;
  for (StringRef ind : input_set) {
    if (!output_set.contains(ind)) summation_axes.insert(ind);
  }
  return summation_axes;
}

// Given a 1:1 map from std::string -> affine dimension expression, we can get
// the affine expressions of the dimensions that an operand will access, based
// on the operand's input string from the einsum_config.
// For example:
// let string_dim_umap = {'a' : d0, 'b' : d1, 'c' : d2}
// for einsum_config "abc,cb->acb"
// first_input_operand will get umap[{"a","b","c"}] -> (d0, d1, d2).
// second_input_operand will get umap[{"c","b"}] -> (d2, d1).
// output_operand will get umap[{"a","c","b"}] -> (d0, d2, d1).
SmallVector<AffineExpr> GetExprFromConfig(
    const SmallVector<std::string>& loop_dims,
    const DenseMap<StringRef, AffineExpr>& str_affine_dim_umap) {
  SmallVector<AffineExpr> exprs;
  for (const auto& dim : loop_dims) {
    exprs.push_back(str_affine_dim_umap.lookup(dim));
  }
  return exprs;
}

// Convert mhlo.einsum op into linalg.generic.
// The algorithm proceeds in 3 general steps:

// Step 1) Dissect the entire einsum_config into its different operands,
// e.g. f("abc,cd->abd") = {lhs:["abc"], rhs:["cd"], out:["abd"]}.

// Step 2) Split each string up into a vector of its elements,
// e.g. {lhs:["abc"], rhs:["cd"], out:["abd"]} = {lhs:["a","b","c"],
// rhs:["c","d"], out:["a","b","d"]}.

// Step 3) Convert each vector into a data access pattern represented by
// AffineMaps with affine dimensions, e.g.
// {lhs:["a","b","c"], rhs:["c","d"], out:["a","b","d"]} = {lhs:[d0,d1,d2],
// rhs:[d2,d3], out:[d0,d1,d3]}.
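//
// For illustration, for config "abc,cd->abd" (summation axis "c") the emitted
// op is roughly:
//   linalg.generic {
//       indexing_maps = [(d0, d1, d2, d3) -> (d0, d1, d2),
//                        (d0, d1, d2, d3) -> (d2, d3),
//                        (d0, d1, d2, d3) -> (d0, d1, d3)],
//       iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
//       ins(%lhs, %rhs) outs(%zero_filled_init) {
//     ^bb0(%l, %r, %out):
//       %prod = mulf %l, %r
//       %acc = addf %out, %prod
//       linalg.yield %acc
//   }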
class EinsumToLinalgConverter : public OpConversionPattern<mhlo::EinsumOp> {
 public:
  using OpConversionPattern<mhlo::EinsumOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::EinsumOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    mhlo::EinsumOp::Adaptor adaptor(args);
    auto get_rank = [](Value v) {
      return v.getType().cast<ShapedType>().getRank();
    };
    auto einsum_config = op.einsum_config();

    // Under the assumption of binary input operands and a single output,
    // extract the input and output operands' indices.
    // einsum_config = "lhs_loop,rhs_loop->out_loop"
    std::size_t pos_arrow = einsum_config.find(kArrow);
    std::size_t pos_comma = einsum_config.find(kComma);

    StringRef lhs_loop = einsum_config.substr(0, pos_comma);
    StringRef rhs_loop = einsum_config.substr(
        pos_comma + kComma.size(), pos_arrow - (pos_comma + kComma.size()));
    StringRef out_loop = einsum_config.substr(pos_arrow + kArrow.size());

    // Check for invalid configs:
    // 1. Check that there are at most 2 inputs.
    // 2. Check that there is at most 1 output.
    // 3. Check that there is exactly 1 kArrow.
    if (rhs_loop.find(kComma) != std::string::npos ||
        out_loop.find(kComma) != std::string::npos ||
        out_loop.find(kArrow) != std::string::npos) {
      return rewriter.notifyMatchFailure(op, "Invalid einsum config!");
    }

    // Find the result type, if on tensors.
    auto result_ty = this->typeConverter
                         ->convertType(GetHloOpResultType</*isLHLO=*/false>(op))
                         .dyn_cast<RankedTensorType>();

    // Check result type compatibility.
    if (!result_ty || !(result_ty.getElementType().isSignlessIntOrFloat())) {
      return rewriter.notifyMatchFailure(op, "Invalid result type");
    }

    // Convert the representation to vector<string>.
    SmallVector<std::string> lhs_ein =
        GetEinsumConfigAsVector(lhs_loop, get_rank(adaptor.lhs()));
    SmallVector<std::string> rhs_ein =
        GetEinsumConfigAsVector(rhs_loop, get_rank(adaptor.rhs()));
    SmallVector<std::string> out_ein =
        GetEinsumConfigAsVector(out_loop, result_ty.getRank());

    if (!CheckBatchHasEqualRank(lhs_ein.size(), lhs_loop, rhs_ein.size(),
                                rhs_loop, out_ein.size(), out_loop)) {
      return rewriter.notifyMatchFailure(
          op, "Invalid ellipsis('...') within einsum config!");
    }

    // Find all unique indices in the input and output.
    llvm::SmallSetVector<StringRef, 4> input_ind;
    llvm::SmallSetVector<StringRef, 4> output_ind;

    input_ind.insert(lhs_ein.begin(), lhs_ein.end());
    input_ind.insert(rhs_ein.begin(), rhs_ein.end());
    output_ind.insert(out_ein.begin(), out_ein.end());

    llvm::SmallSetVector<StringRef, 4> reduction_axe =
        FindSummationAxes(input_ind, output_ind);

    // Find input/output values and types.
    auto loc = op.getLoc();

    // Prepare init tensor for linalg.generic op.
    auto dyn_sizes = ExtractDynamicEinsumSizes(
        rewriter, loc, adaptor.lhs(), adaptor.rhs(), lhs_ein, rhs_ein, out_ein);
    Value output = GetInitTensor(rewriter, loc, result_ty, dyn_sizes);
    if (!reduction_axe.empty()) {
      auto zero_attr = rewriter.getZeroAttr(result_ty.getElementType());
      Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
      output = rewriter.create<linalg::FillOp>(loc, zero, output).getResult(0);
    }

    // Create indexing maps.
    // Create a 1:1 map from f:strDimension -> affineDimension.
    int64_t nloops = input_ind.size();
    DenseMap<StringRef, AffineExpr> str_affine_dim_umap;
    for (auto it : llvm::enumerate(input_ind)) {
      str_affine_dim_umap[it.value()] = rewriter.getAffineDimExpr(it.index());
    }

    // From the einsum_config of each operand in vector<string> form, generate
    // the equivalent vector<AffineExpr>.
    SmallVector<AffineMap, 4> maps;
    for (const SmallVector<std::string>& loop_operand :
         {lhs_ein, rhs_ein, out_ein}) {
      auto exprs = GetExprFromConfig(loop_operand, str_affine_dim_umap);
      maps.push_back(AffineMap::get(nloops, 0, exprs, rewriter.getContext()));
    }

    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, result_ty ? result_ty : TypeRange{}, args, output, maps,
        GetEinsumLoopsAttrs(input_ind, reduction_axe),
        [&](OpBuilder& b, Location nested_loc, ValueRange args) {
          Value result_val =
              b.create<mlir::MulFOp>(nested_loc, args[0], args[1]);
          if (!reduction_axe.empty()) {
            result_val =
                b.create<mlir::AddFOp>(nested_loc, args[2], result_val);
          }
          b.create<linalg::YieldOp>(nested_loc, result_val);
        });
    rewriter.replaceOp(op, linalg_op.getResults());
    return success();
  }

 private:
  static constexpr StringRef kArrow = "->";
  static constexpr StringRef kComma = ",";
  static constexpr StringRef kEllipsis = "...";

  static bool CheckBatchHasEqualRank(size_t lhs_rank, StringRef lhs_loop,
                                     size_t rhs_rank, StringRef rhs_loop,
                                     size_t out_rank, StringRef out_loop);
  static SmallVector<std::string> GetEinsumConfigAsVector(StringRef loop,
                                                          size_t operand_rank);
};

// Definition of util const member variables.
constexpr StringRef EinsumToLinalgConverter::kArrow;
constexpr StringRef EinsumToLinalgConverter::kComma;
constexpr StringRef EinsumToLinalgConverter::kEllipsis;

// Convert the representation from string/vector<char> to vector<string>,
// i.e. ("abc") -> {"a", "b", "c"}. For cases with an ellipsis and batch rank 3:
// loop_dim = f("ab...cde") = {"a","b","0","1","2","c","d","e"}
SmallVector<std::string> EinsumToLinalgConverter::GetEinsumConfigAsVector(
    StringRef loop, size_t operand_rank) {
  SmallVector<std::string> loop_dim;
  size_t pre_elip = loop.find(kEllipsis);
  bool has_elip = pre_elip != std::string::npos;
  if (!has_elip) pre_elip = loop.size();
  // Add the dimensions until the end, or up to the ellipsis if it exists.
  for (int pre_elip_ind = 0; pre_elip_ind < pre_elip; pre_elip_ind++) {
    loop_dim.push_back(loop.substr(pre_elip_ind, 1).str());
  }
  if (!has_elip) return loop_dim;
  // Case where an ellipsis is present:
  size_t non_batch_rank = loop.size() - kEllipsis.size();
  size_t batch_rank = operand_rank - non_batch_rank;
  // Add the batch dimensions ("0",...,"N"), where N is the rank of the batch,
  // into the loop.
  for (int batch_ind = 0; batch_ind < batch_rank; batch_ind++) {
    loop_dim.push_back(std::to_string(batch_ind));
  }
  // Add the dimensions after the ellipsis into the loop.
  int post_elip = pre_elip + kEllipsis.size();
  for (int post_elip_ind = post_elip; post_elip_ind < loop.size();
       post_elip_ind++) {
    loop_dim.push_back(loop.substr(post_elip_ind, 1).str());
  }
  return loop_dim;
}

// Returns true if all operands' batches have the same rank.
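// E.g., for an lhs config "...ab" on an operand of rank 4, the implied batch
// rank is 4 - 2 = 2; every operand whose config uses the ellipsis must imply
// that same batch rank.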
bool EinsumToLinalgConverter::CheckBatchHasEqualRank(
    size_t lhs_rank, StringRef lhs_loop, size_t rhs_rank, StringRef rhs_loop,
    size_t out_rank, StringRef out_loop) {
  SmallVector<int, 3> batch_rank_vec;
  if (lhs_rank != lhs_loop.size()) {
    size_t lhs_batch_rank = lhs_rank - (lhs_loop.size() - kEllipsis.size());
    batch_rank_vec.push_back(lhs_batch_rank);
  }
  if (rhs_rank != rhs_loop.size()) {
    size_t rhs_batch_rank = rhs_rank - (rhs_loop.size() - kEllipsis.size());
    batch_rank_vec.push_back(rhs_batch_rank);
  }
  if (out_rank != out_loop.size()) {
    size_t out_batch_rank = out_rank - (out_loop.size() - kEllipsis.size());
    batch_rank_vec.push_back(out_batch_rank);
  }
  bool batch_has_equal_rank = true;

  // The condition holds trivially if at most one operand has a batch.
  if (batch_rank_vec.size() < 2) return batch_has_equal_rank;
  if (!std::equal(batch_rank_vec.begin() + 1, batch_rank_vec.end(),
                  batch_rank_vec.begin()) &&
      batch_rank_vec.size() > 1)
    batch_has_equal_rank = false;
  return batch_has_equal_rank;
}

template <typename OpTy, bool isLHLO = true>
class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    // Find the maximum rank / number of loops.
    auto get_rank = [](Value v) {
      return v.getType().cast<ShapedType>().getRank();
    };
    auto is_scalar = [&](Value v) { return get_rank(v) == 0; };
    auto it = llvm::find_if_not(args, is_scalar);
    Value max_rank_arg = it != args.end() ? *it : args.front();
    int64_t nloops = get_rank(max_rank_arg);

    if (isLHLO && nloops == 0) return failure();

    // Apply only if all operands are scalar or have the same rank. Some ops,
    // like `mhlo.select`, support implicit broadcasting of scalars.
    if (!llvm::all_of(args, [&](Value v) {
          int64_t r = get_rank(v);
          return r == 0 || r == nloops;
        })) {
      return rewriter.notifyMatchFailure(
          op, "Operands must be of the same rank or scalar.");
    }

    // Find result type, if on tensors.
    Optional<ShapedType> result_ty;
    if (!isLHLO) {
      result_ty = this->typeConverter->convertType(op->getResultTypes().front())
                      .template dyn_cast<ShapedType>();

      // Check result type compatibility.
      if (!result_ty || !result_ty->hasRank() ||
          result_ty->getRank() != nloops ||
          !(result_ty->getElementType().isSignlessIntOrFloat() ||
            result_ty->getElementType().isa<ComplexType>())) {
        return rewriter.notifyMatchFailure(
            op, "mismatched operand/result types or iterator count");
      }
    }

    // Find input/output values and types.
    auto loc = op.getLoc();
    ValueRange inputs = isLHLO ? args.drop_back() : args;
    Value output;
    if (isLHLO) {
      output = args.back();
    } else {
      auto dyn_sizes = ExtractDynamicSizes(rewriter, loc, max_rank_arg);
      output = GetInitTensor(rewriter, loc, *result_ty, dyn_sizes);
    }

    // Create indexing maps.
    AffineMap scalar_map = AffineMap::get(nloops, 0, rewriter.getContext());
    AffineMap id_map = rewriter.getMultiDimIdentityMap(nloops);
    SmallVector<AffineMap, 4> maps;
    for (Value v : inputs) maps.push_back(is_scalar(v) ? scalar_map : id_map);
    maps.push_back(id_map);

    // Build `linalg.generic` op.
    bool failed = false;
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, result_ty ? *result_ty : TypeRange{}, inputs, output, maps,
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
          // TODO(ravishankarm) : For now use the method in lmhlo namespace.
          // That method needs to be moved out of there.
          Type inner_result_ty = getElementTypeOrSelf(output);
          Value inner_result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
              op, inner_result_ty,
              llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
          if (inner_result == nullptr) {
            failed = true;
          } else {
            nested_builder.create<linalg::YieldOp>(loc, inner_result);
          }
        });
    if (failed) return failure();
    rewriter.replaceOp(op, linalg_op->getResults());
    return success();
  }
};

template <typename LhloOp>
class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
 public:
  using OpConversionPattern<LhloOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      LhloOp lhlo_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = lhlo_op.getLoc();
    auto arg_type =
        lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>();
    if (!arg_type || !arg_type.getElementType().isSignlessIntOrFloat() ||
        (arg_type.getRank() != 0)) {
      return failure();
    }

    // Create two loads from the input.
    auto lhs = rewriter.create<memref::LoadOp>(loc, lhlo_op.lhs());
    auto rhs = rewriter.create<memref::LoadOp>(loc, lhlo_op.rhs());
    // TODO(ravishankarm) : Move this method out of lmhlo namespace.
    Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOp>(
        lhlo_op, arg_type.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
        &rewriter);
    rewriter.create<memref::StoreOp>(loc, op_result, lhlo_op.out());
    rewriter.eraseOp(lhlo_op);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// lmhlo.convolution conversion pattern.
//===----------------------------------------------------------------------===//

/// Converts an lmhlo.convolution operation to a linalg.conv op.
struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
 public:
  using OpConversionPattern<lmhlo::ConvOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::ConvOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    if (!HasCanonicalDimensionNumbers(op.dimension_numbers())) return failure();

    // TODO: LHS dilation for deconvolution is not supported yet.
    // TODO(jurahul): Window reversal is not supported yet.
    if (op.lhs_dilation() || op.hasWindowReversal()) {
      return failure();
    }

    llvm::SmallVector<Attribute, 4> strides;
    if (auto window_strides = op.window_strides()) {
      auto range = window_strides->getAttributeValues();
      strides.assign(range.begin(), range.end());
    }
    auto strides_arg = ArrayAttr::get(op.getContext(), strides);

    llvm::SmallVector<Attribute, 2> dilation;
    if (auto rhs_dilation = op.rhs_dilation()) {
      auto range = rhs_dilation->getAttributeValues();
      dilation.assign(range.begin(), range.end());
    } else {
      // Default dilation of 1.
      dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1));
    }
    auto dilation_arg = ArrayAttr::get(op.getContext(), dilation);

    // Set padding only if it is non-zero.
    DenseIntElementsAttr padding = op.paddingAttr();
    if (!padding ||
        !llvm::any_of(padding.getValues<APInt>(),
                      [](APInt int_val) { return !int_val.isNullValue(); })) {
      padding = nullptr;
    }

    // The order of input and filter is switched with linalg.conv.
    rewriter.replaceOpWithNewOp<linalg::ConvOp>(
        op, args[1], args[0], args[2], strides_arg, dilation_arg, padding);
    return success();
  }
};

/// Base class for lowering HLO operations that have one operand and one
/// result, and are semantically equivalent to a copy of the input to the
/// output (like transpose, some reshape, etc.). The derived classes need to
/// provide a method `getIndexingMaps` that returns AffineMaps for the index
/// maps of the input and the output.
template <typename Derived, typename OpTy, bool isLHLO = true>
class DataMovementOpConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    if (!VerifyHloOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
    auto result_type = GetHloOpResultType<isLHLO>(op);
    result_type = this->typeConverter->convertType(result_type)
                      .template cast<ShapedType>();

    SmallVector<AffineMap, 2> indexing_maps =
        Derived::getIndexingMaps(op, &rewriter);
    if (indexing_maps.empty()) return failure();

    auto nloops = result_type.getRank();
    auto loc = op.getLoc();
    AffineMap shape_permutation =
        indexing_maps[0].isPermutation() ? indexing_maps[0] : AffineMap();
    // TODO(pifon): Technically, the op itself could have size operands (e.g.
    // broadcast into a dynamic dimension). Handle this case.
    auto dyn_sizes = isLHLO ? SmallVector<Value, 2>()
                            : ExtractDynamicSizes(rewriter, loc, args[0],
                                                  nullptr, shape_permutation);
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc,
        /*resultTensorTypes=*/isLHLO ? ArrayRef<Type>{} : result_type,
        /*inputs=*/args.front(),
        /*outputBuffers=*/
        isLHLO
            ? ValueRange{args.back()}
            : ValueRange{GetInitTensor(rewriter, loc, result_type, dyn_sizes)},
        indexing_maps, GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
          nested_builder.create<linalg::YieldOp>(loc, *args.begin());
        });
    rewriter.replaceOp(op, linalg_op.getOperation()->getResults());
    return success();
  }
};

/// Pattern to convert BroadcastOp to Linalg ops.
template <typename OpTy, bool isLHLO = true>
class BroadcastConverter
    : public DataMovementOpConverter<BroadcastConverter<OpTy, isLHLO>, OpTy,
                                     isLHLO> {
 public:
  using DataMovementOpConverter<BroadcastConverter, OpTy,
                                isLHLO>::DataMovementOpConverter;

  static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcast_op,
                                                   Builder* b) {
    ShapedType input_type =
        broadcast_op.operand().getType().template cast<ShapedType>();
    unsigned input_rank = input_type.getRank();
    unsigned nloops = GetHloOpResultType<isLHLO>(broadcast_op).getRank();

    // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute
    // to the input's dimensions.
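    // E.g. (illustrative), with an input of rank 2 and two prepended
    // dimensions, the input map computed below is (d0, d1, d2, d3) -> (d2, d3).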
    unsigned num_prepended_dims = llvm::size(broadcast_op.broadcast_sizes());
    SmallVector<AffineExpr, 4> input_dim_exprs;
    input_dim_exprs.reserve(input_rank);
    for (unsigned i = 0; i < input_rank; ++i) {
      input_dim_exprs.push_back(b->getAffineDimExpr(num_prepended_dims + i));
    }

    AffineMap input_map;
    MLIRContext* context = b->getContext();
    if (input_dim_exprs.empty()) {
      // The input is a scalar, i.e. this is a scalar broadcast op.
      input_map = AffineMap::get(nloops, /*symbolCount=*/0, context);
    } else {
      input_map =
          AffineMap::get(nloops, /*symbolCount=*/0, input_dim_exprs, context);
    }
    return {input_map, b->getMultiDimIdentityMap(nloops)};
  }
};

class HloBroadcastInDimConverter
    : public DataMovementOpConverter<HloBroadcastInDimConverter,
                                     mhlo::BroadcastInDimOp, false> {
 public:
  using DataMovementOpConverter<HloBroadcastInDimConverter,
                                mhlo::BroadcastInDimOp,
                                false>::DataMovementOpConverter;

  static SmallVector<AffineMap, 2> getIndexingMaps(
      mhlo::BroadcastInDimOp broadcast_op, Builder* b) {
    auto result_type = GetHloOpResultType<false>(broadcast_op);
    auto operand_type =
        broadcast_op.operand().getType().template cast<ShapedType>();
    unsigned nloops = result_type.getRank();

    // The input is a scalar, i.e. this is a scalar broadcast op.
    if (operand_type.getRank() == 0) {
      return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
              b->getMultiDimIdentityMap(nloops)};
    }

    auto operand_shape = operand_type.getShape();
    SmallVector<AffineExpr, 4> dim_exprs;
    dim_exprs.reserve(nloops);

    if (broadcast_op.broadcast_dimensions()) {
      for (const auto& broadcastDim :
           enumerate(broadcast_op.broadcast_dimensions().getIntValues())) {
        int size = broadcastDim.value().getSExtValue();
        bool expansion_needed = operand_shape[broadcastDim.index()] == 1 &&
                                result_type.getShape()[size] != 1;
        dim_exprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
                                             : b->getAffineDimExpr(size));
      }
    }
    return {
        AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
        b->getMultiDimIdentityMap(nloops)};
  }
};

class HloDynamicBroadcastInDimConverter
    : public OpConversionPattern<mhlo::DynamicBroadcastInDimOp> {
 public:
  using OpConversionPattern<mhlo::DynamicBroadcastInDimOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    // If the input has a static shape we know exactly when the broadcast must
    // expand (the dimension is 1, which also trivially expands to 1) or will
    // never expand (the dimension is not 1). This means we can lower the
    // broadcast just as we would lower a fully static broadcast and go directly
    // to linalg.generic. This also covers the important case of broadcasting a
    // scalar.
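    // E.g., broadcasting tensor<1x4xf32> along dims [0, 1] into
    // tensor<?x4xf32>: dim 0 is statically 1 and always expands (constant-0
    // expression), while dim 1 never expands and keeps its loop index.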

    // Ideally the pattern (`mhlo.constant` -> `mhlo.dynamic_broadcast_in_dim`)
    // should be converted to a tensor-dialect op similar to TF's ConstantLikeOp.

    mhlo::DynamicBroadcastInDimOp::Adaptor adaptor(operands);
    Value operand = adaptor.operand();
    auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
    if (!operand_type || !operand_type.hasStaticShape()) return failure();

    Value shape = adaptor.output_dimensions();
    auto shape_type = shape.getType().cast<RankedTensorType>();
    int64_t result_rank = shape_type.getDimSize(0);
    // HLO dimension types can be any integer, as well as index.
    bool convert_to_index =
        shape_type.getElementType() != rewriter.getIndexType();

    auto result_type =
        typeConverter->convertType(op.getType()).dyn_cast<RankedTensorType>();
    if (!result_type) return failure();

    SmallVector<Value, 2> dyn_dims;
    Location loc = op.getLoc();
    for (int i = 0; i < result_rank; ++i) {
      if (!result_type.isDynamicDim(i)) continue;
      Value index = rewriter.create<ConstantIndexOp>(loc, i);
      Value dim = rewriter.create<tensor::ExtractOp>(loc, shape, index);
      if (convert_to_index) {
        dim = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), dim);
      }
      dyn_dims.push_back(dim);
    }

    int64_t nloops = result_type.getRank();
    auto operand_shape = operand_type.getShape();
    SmallVector<AffineExpr, 4> dim_exprs;
    dim_exprs.reserve(nloops);

    if (op.broadcast_dimensions()) {
      for (const auto& broadcast_dim :
           enumerate(op.broadcast_dimensions().getIntValues())) {
        int64_t size = broadcast_dim.value().getSExtValue();
        bool expansion_needed = operand_shape[broadcast_dim.index()] == 1;
        dim_exprs.push_back(expansion_needed ? rewriter.getAffineConstantExpr(0)
                                             : rewriter.getAffineDimExpr(size));
      }
    }

    Value init = rewriter.create<linalg::InitTensorOp>(
        loc, dyn_dims, result_type.getShape(), result_type.getElementType());
    Operation* generic = rewriter.create<linalg::GenericOp>(
        loc, TypeRange{init.getType()}, ValueRange{operand},
        /*outputBuffers=*/ValueRange{init},
        llvm::makeArrayRef(
            {AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0, dim_exprs,
                            rewriter.getContext()),
             rewriter.getMultiDimIdentityMap(nloops)}),
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
          nested_builder.create<linalg::YieldOp>(loc, *args.begin());
        });
    rewriter.replaceOp(op, generic->getResults());
    return success();
  }
};

class LhloBroadcastInDimConverter
    : public OpConversionPattern<lmhlo::BroadcastInDimOp> {
 public:
  using OpConversionPattern<lmhlo::BroadcastInDimOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::BroadcastInDimOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
    auto result_type = operand_adaptor.output().getType().cast<MemRefType>();
    auto result_shape = result_type.getShape();

    auto operand_and_dims = InsertReshapeIfNecessary(op, args, rewriter);

    Value operand = std::get<0>(operand_and_dims);
    auto broadcast_dims = std::get<1>(operand_and_dims);

    auto loc = op.getLoc();
    auto nloops = result_type.getRank();
    auto operand_type = operand.getType().cast<MemRefType>();

    // For a degenerate case, i.e. broadcasting with expansion of
    // memref<1xELEMENT_TYPE>, the operand is not passed to `linalg.generic`.
    // Instead the value is loaded and used directly in `linalg.yield`.
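    // E.g. (illustrative), expanding memref<1xf32> into memref<4xf32>: the
    // single element is loaded once outside the generic op and yielded for
    // every output index.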
    if (operand_type.getRank() == 1 &&
        operand_type.getDimSize(0) <
            result_type.getDimSize(broadcast_dims.front())) {
      Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
      Value val = rewriter.create<memref::LoadOp>(loc, operand,
                                                  llvm::makeArrayRef({zero}));
      rewriter.create<linalg::GenericOp>(
          loc, /*inputs=*/ValueRange{},
          /*outputBuffers=*/ValueRange{operand_adaptor.output()},
          llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
          GetNParallelLoopsAttrs(nloops),
          [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
            nested_builder.create<linalg::YieldOp>(loc, val);
          });

    } else {
      auto indexing_maps = getIndexingMaps(op, broadcast_dims, result_shape,
                                           operand_type, &rewriter);
      rewriter.create<linalg::GenericOp>(
          loc, /*inputs=*/ValueRange{operand},
          /*outputBuffers=*/ValueRange{operand_adaptor.output()}, indexing_maps,
          GetNParallelLoopsAttrs(nloops),
          [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
            nested_builder.create<linalg::YieldOp>(loc, *args.begin());
          });
    }
    rewriter.replaceOp(op, llvm::None);
    return success();
  }

  // Inserts a `memref.collapse_shape` op if there is a size-1 dim expansion.
  std::pair<Value, SmallVector<int64_t, 2>> InsertReshapeIfNecessary(
      lmhlo::BroadcastInDimOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const {
    lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
    Value operand = operand_adaptor.operand();
    auto operand_type = operand_adaptor.operand().getType().cast<MemRefType>();
    auto operand_shape = operand_type.getShape();

    Value result = operand_adaptor.output();
    auto result_type = result.getType().cast<MemRefType>();
    auto result_shape = result_type.getShape();

    SmallVector<int64_t, 2> operand_strides;
    int64_t operand_offset;
    if (failed(getStridesAndOffset(operand_type, operand_strides,
                                   operand_offset))) {
      op.emitOpError() << "Failed to get offset and strides.";
    }

    SmallVector<int64_t, 2> new_shape, new_strides, broadcast_dims;
    SmallVector<ReassociationIndices, 4> collapsed_dims_list;
    ReassociationIndices collapsed_dims;
    for (const auto& item :
         enumerate(op.broadcast_dimensions().getIntValues())) {
      size_t index = item.index();
      int dim = item.value().getSExtValue();

      collapsed_dims.push_back(index);

      bool expansion_needed =
          operand_shape[index] == 1 && result_shape[dim] != 1;
      if (expansion_needed) {
        continue;
      }
      new_shape.push_back(operand_shape[index]);
      new_strides.push_back(operand_strides[index]);
      broadcast_dims.push_back(dim);

      collapsed_dims_list.push_back(collapsed_dims);
      collapsed_dims.clear();
    }
    // If `collapsed_dims_list` is empty, then the memref has shape [1, ..., 1]
    // and all dimensions need expansion. Such a memref will be reshaped to a
    // 1D memref with a single element. The new shape and strides need to be
    // updated accordingly.
    if (collapsed_dims_list.empty()) {
      collapsed_dims_list.push_back({});
      new_shape.push_back(1);
      new_strides.push_back(1);
      broadcast_dims.push_back(0);
    }
    for (const auto& dims : collapsed_dims) {
      collapsed_dims_list.back().push_back(dims);
    }

    // A `memref.collapse_shape` is inserted only if necessary, i.e. when the
    // rank can be reduced.
    if (new_shape.size() < operand_shape.size()) {
      auto new_memref_type = MemRefType::get(
          new_shape, operand_type.getElementType(),
          makeStridedLinearLayoutMap(new_strides, operand_offset,
                                     rewriter.getContext()));
      operand = rewriter.create<memref::CollapseShapeOp>(
          op.getLoc(), new_memref_type, operand_adaptor.operand(),
          collapsed_dims_list);
    }
    return std::make_pair(operand, broadcast_dims);
  }

  SmallVector<AffineMap, 2> getIndexingMaps(lmhlo::BroadcastInDimOp op,
                                            ArrayRef<int64_t> broadcast_dims,
                                            ArrayRef<int64_t> result_shape,
                                            MemRefType operand_type,
                                            Builder* b) const {
    unsigned nloops = result_shape.size();

    // The input is a scalar, i.e. this is a scalar broadcast op.
    if (operand_type.getRank() == 0) {
      return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
              b->getMultiDimIdentityMap(nloops)};
    }

    auto operand_shape = operand_type.getShape();
    SmallVector<AffineExpr, 4> dim_exprs;
    dim_exprs.reserve(nloops);

    for (const auto& broadcast_dim : llvm::enumerate(broadcast_dims)) {
      int size = broadcast_dim.value();
      bool expansion_needed =
          operand_shape[broadcast_dim.index()] == 1 && result_shape[size] != 1;
      if (expansion_needed) {
        op.emitOpError(
            "BroadcastInDimOp lowering to Linalg does not support size-1 "
            "dimensions expansion.");
      }
      dim_exprs.push_back(b->getAffineDimExpr(size));
    }
    return {
        AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
        b->getMultiDimIdentityMap(nloops)};
  }
};

template <typename OpTy, bool isLHLO = true>
class TransposeConverter
    : public DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
                                     isLHLO> {
 public:
  using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
                                isLHLO>::DataMovementOpConverter;
  static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
    auto result_type =
        GetHloOpResultType<isLHLO>(op).template cast<ShapedType>();
    auto nloops = result_type.getRank();
    SmallVector<AffineExpr, 2> input_exprs;
    input_exprs.resize(result_type.getRank());
    for (auto permutation : llvm::enumerate(op.permutation())) {
      input_exprs[permutation.value().getZExtValue()] =
          b->getAffineDimExpr(permutation.index());
    }
    return {
        AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
        b->getMultiDimIdentityMap(nloops)};
  }
};

// Converts reshape ops that can be proven to be either a collapse of
// dimensions or an expansion of dimensions of the operand.
template <typename OpTy, bool isLHLO = true>
class ReshapeOpConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy reshape_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    if (!VerifyHloOpBufferOrTensorSemantics<isLHLO>(reshape_op))
      return failure();
    typename OpTy::Adaptor operands(args);
    ShapedType operand_type =
        operands.operand().getType().template cast<ShapedType>();
    ShapedType result_type = GetHloOpResultType<isLHLO>(reshape_op);

    if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
      return failure();

    result_type = this->typeConverter->convertType(result_type)
                      .template cast<ShapedType>();

    // Compute the reassociation maps for the linalg operation.
    ArrayRef<int64_t> src_shape =
        (operand_type.getRank() > result_type.getRank()
             ? operand_type.getShape()
             : result_type.getShape());
    ArrayRef<int64_t> dst_shape =
        (operand_type.getRank() > result_type.getRank()
             ? result_type.getShape()
             : operand_type.getShape());
    unsigned curr_src_dim = 0, curr_dst_dim = 0;
    SmallVector<ReassociationExprs, 4> reassociation_map(dst_shape.size());

    // First scan all dimensions in the source shapes to see whether we have a
    // perfect case where consecutive dimensions in the source are collapsed.
    // For such cases we can just generate a single linalg.reshape.
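    // E.g., collapsing [2, 3, 4] into [6, 4] produces the reassociation
    // {[d0, d1], [d2]}.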
    bool is_collapsing_source = true;
    while (curr_src_dim < src_shape.size() && curr_dst_dim < dst_shape.size()) {
      int64_t dst_size = dst_shape[curr_dst_dim];
      int64_t src_size = src_shape[curr_src_dim];
      while (src_size < dst_size && curr_src_dim < src_shape.size()) {
        reassociation_map[curr_dst_dim].push_back(
            rewriter.getAffineDimExpr(curr_src_dim++));
        src_size *= src_shape[curr_src_dim];
      }
      if (src_size == dst_size) {
        reassociation_map[curr_dst_dim].push_back(
            rewriter.getAffineDimExpr(curr_src_dim++));
        // If the next dim in dst_shape is not 1, treat subsequent dims in
        // src_shape which are 1 as collapsed.
        if (curr_dst_dim == dst_shape.size() - 1 ||
            dst_shape[curr_dst_dim + 1] != 1) {
          while (curr_src_dim < src_shape.size() &&
                 src_shape[curr_src_dim] == 1) {
            reassociation_map[curr_dst_dim].push_back(
                rewriter.getAffineDimExpr(curr_src_dim++));
          }
        }
      } else {
        is_collapsing_source = false;
        break;
      }
      curr_dst_dim++;
    }
    // Rank 0 can always use the direct lowering.
    if (!src_shape.empty() && !dst_shape.empty() &&
        (curr_src_dim != src_shape.size() || curr_dst_dim != dst_shape.size()))
      is_collapsing_source = false;

    // Otherwise, we need to first reduce all source dimensions into one and
    // then expand to the destination dimensions.
    if (!is_collapsing_source) {
      auto get_identity_exprs = [&rewriter](int n) {
        SmallVector<AffineExpr, 4> exprs;
        for (int i = 0; i < n; ++i)
          exprs.push_back(rewriter.getAffineDimExpr(i));
        return exprs;
      };
      Location loc = reshape_op.getLoc();
      int64_t total_elems = std::accumulate(src_shape.begin(), src_shape.end(),
                                            1, std::multiplies<int64_t>());
      auto elem_type = operand_type.getElementType();
      SmallVector<ReassociationExprs, 4> collapsing_map = {
          // Use operand_type here because we need to collapse all operand
          // dimensions.
          get_identity_exprs(operand_type.getShape().size())};
      SmallVector<ReassociationExprs, 4> expanding_map = {
          // Use result_type here because we need to expand to all result
          // dimensions.
          get_identity_exprs(result_type.getShape().size())};

      if (isLHLO) {
        auto collapsed_type = MemRefType::get({total_elems}, elem_type);
        Value collapsed_op = rewriter.create<memref::CollapseShapeOp>(
            loc, collapsed_type, args[0], collapsing_map);
        Value reshape_buffer = rewriter.create<memref::ExpandShapeOp>(
            loc, result_type, collapsed_op, expanding_map);
        rewriter.replaceOpWithNewOp<linalg::CopyOp>(reshape_op, reshape_buffer,
                                                    args[1]);
      } else {
        auto collapsed_type = RankedTensorType::get({total_elems}, elem_type);
        Value collapsed_op = rewriter.create<linalg::TensorCollapseShapeOp>(
            loc, collapsed_type, args[0], collapsing_map);
        rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
            reshape_op, result_type, collapsed_op, expanding_map);
      }
      return success();
    }

    bool isCollapsing =
        result_type.getRank() < args[0].getType().cast<ShapedType>().getRank();
    if (isLHLO) {
      Value reshape_buffer = isCollapsing
                                 ? rewriter
                                       .create<memref::CollapseShapeOp>(
                                           reshape_op.getLoc(), result_type,
                                           args[0], reassociation_map)
                                       .getResult()
                                 : rewriter
                                       .create<memref::ExpandShapeOp>(
                                           reshape_op.getLoc(), result_type,
                                           args[0], reassociation_map)
                                       .getResult();
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(reshape_op, reshape_buffer,
                                                  args[1]);
    } else {
      if (isCollapsing) {
        rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
            reshape_op, result_type, args[0], reassociation_map);
      } else {
        rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
            reshape_op, result_type, args[0], reassociation_map);
      }
    }
    return success();
  }
};

template <typename OpTy, bool isLHLO = true>
class IotaConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy iota_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    ShapedType result_shaped_type = GetHloOpResultType<isLHLO>(iota_op);
    if (!result_shaped_type) return failure();
    result_shaped_type = this->typeConverter->convertType(result_shaped_type)
                             .template dyn_cast<ShapedType>();

    auto result_element_type = result_shaped_type.getElementType();
    if (!result_element_type.isSignlessIntOrFloat()) return failure();

    // Construct the indexing maps needed for linalg.generic ops.
    unsigned nloops = result_shaped_type.getRank();

    Location loc = iota_op.getLoc();
    // If this is a dynamic iota, the first argument will be a shape tensor.
    Value shape_tensor = args.size() > (isLHLO ? 1 : 0) ? args[0] : nullptr;
    auto dyn_sizes =
        isLHLO
            ? SmallVector<Value, 2>()
            : ExtractDynamicSizes(
                  rewriter, loc, GetResultValue<isLHLO>(iota_op), shape_tensor);
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc,
        /*resultTensorTypes=*/
        isLHLO ? ArrayRef<Type>{} : ArrayRef<Type>{result_shaped_type},
        /*inputs=*/ValueRange{},
        /*outputBuffers=*/
        isLHLO ? ValueRange{args.back()}
               : ValueRange{GetInitTensor(rewriter, loc, result_shaped_type,
                                          dyn_sizes)},
        llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) {
          Value index_op = nested_builder.create<linalg::IndexOp>(
              nested_loc, iota_op.iota_dimension());
          Value cast_op = nested_builder.create<IndexCastOp>(
              nested_loc, index_op,
              nested_builder.getIntegerType(
                  result_element_type.getIntOrFloatBitWidth()));
          if (result_element_type.template isa<FloatType>()) {
            cast_op = nested_builder.create<SIToFPOp>(nested_loc, cast_op,
                                                      result_element_type);
          }
          nested_builder.create<linalg::YieldOp>(nested_loc, cast_op);
        });
    if (isLHLO)
      rewriter.replaceOp(iota_op, llvm::None);
    else
      rewriter.replaceOp(iota_op, linalg_op.result_tensors());
    return success();
  }
};

/// Converts mhlo.concatenate operation to a linalg.generic op.
struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> {
  using OpConversionPattern<mhlo::ConcatenateOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ConcatenateOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const override {
    // Shortcut the one-operand case, simplifies code below.
    if (args.size() == 1) {
      rewriter.replaceOp(op, args[0]);
      return success();
    }

    auto result_type =
        this->typeConverter->convertType(op.getResult().getType())
            .dyn_cast<RankedTensorType>();
    if (!result_type) return failure();

    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    uint64_t dim = op.dimension();
    int64_t rank = result_type.getRank();
    Value zero = b.create<ConstantIndexOp>(0);
    SmallVector<Value, 3> sizes;
    for (int64_t i = 0; i < rank; ++i) {
      sizes.push_back(i == dim ? Value() : b.create<tensor::DimOp>(args[0], i));
    }

    // Calculate the size of the concatenated dimension.
    Value result_dim_size;
    for (auto arg : args) {
      Value size = b.create<tensor::DimOp>(arg, dim);
      result_dim_size =
          result_dim_size ? b.create<AddIOp>(result_dim_size, size) : size;
    }
    sizes[dim] = result_dim_size;

    // Allocate the output tensor with init_tensor.
    SmallVector<Value, 3> dyn_sizes;
    for (int64_t i = 0; i < rank; ++i) {
      if (result_type.isDynamicDim(i)) dyn_sizes.push_back(sizes[i]);
    }
    Value result = b.create<linalg::InitTensorOp>(
        dyn_sizes, result_type.getShape(), result_type.getElementType());

    // Generate a generic op to gather the elements of the concatenate. This is
    // awkward standalone but allows fusion with other generic ops.
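    // Roughly (illustrative), for two inputs A and B the generated body is:
    //   if (idx[dim] < dim(A, dim)) yield A[idx]
    //   else yield B[idx with idx[dim] - dim(A, dim)]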
    unsigned nloops = result_type.getRank();
    auto linalg_op = b.create<linalg::GenericOp>(
        /*resultTensorTypes=*/result_type,
        /*inputs=*/ValueRange{}, /*outputBuffers=*/result,
        llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nested_builder, Location loc, ValueRange) {
          OpBuilder b = nested_builder;
          Value concat_dim_size = zero;
          Value result;

          SmallVector<Value, 4> extract_indices;
          extract_indices.reserve(nloops);
          for (int i = 0; i < nloops; i++) {
            extract_indices.push_back(b.create<linalg::IndexOp>(loc, i));
          }

          Value index_op = b.create<linalg::IndexOp>(loc, dim);
          for (const Value& arg : args) {
            Value new_concat_dim_size;
            scf::IfOp if_op;
            if (&arg != &args.back()) {
              // Calculate how far along we have iterated along the concatenate
              // dimension. That way we can tell which input to select.
              new_concat_dim_size = b.create<AddIOp>(
                  loc, concat_dim_size, b.create<tensor::DimOp>(loc, arg, dim));
              Value cmp = b.create<CmpIOp>(loc, rewriter.getI1Type(),
                                           CmpIPredicate::ult, index_op,
                                           new_concat_dim_size);
              if_op = b.create<scf::IfOp>(loc, result_type.getElementType(),
                                          cmp, true);
              if (result) {
                b.create<scf::YieldOp>(loc, if_op->getResults()[0]);
              } else {
                result = if_op->getResults()[0];
              }

              b = if_op.getThenBodyBuilder(b.getListener());
            }

            // Now adjust the index for the concatenated dimension to fit into
            // the selected tensor and do an extract at that position.
            extract_indices[dim] =
                b.create<SubIOp>(loc, index_op, concat_dim_size);
            Value extract =
                b.create<tensor::ExtractOp>(loc, arg, extract_indices);
            b.create<scf::YieldOp>(loc, extract);

            if (if_op) {
              b = if_op.getElseBodyBuilder(b.getListener());
              concat_dim_size = new_concat_dim_size;
            }
          }
          nested_builder.create<linalg::YieldOp>(loc, result);
        });
    rewriter.replaceOp(op, linalg_op.result_tensors());
    return success();
  }
};

class ConstConverterBuffer : public OpConversionPattern<lmhlo::ConstOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::ConstOp const_op, ArrayRef<Value> /*args*/,
      ConversionPatternRewriter& rewriter) const final {
    Location loc = const_op.getLoc();
    auto value_attr = const_op.value().cast<DenseElementsAttr>();
    if (value_attr.getType().getRank() != 0) return failure();
    Value std_scalar_const =
        rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({}));
    rewriter.create<mlir::AffineStoreOp>(loc, std_scalar_const,
                                         const_op.getOperand(), llvm::None);
    rewriter.eraseOp(const_op);
    return success();
  }
};

class ConstConverterTensor : public OpConversionPattern<mhlo::ConstOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ConstOp const_op, ArrayRef<Value> /*args*/,
      ConversionPatternRewriter& rewriter) const final {
    auto value_attr = const_op.value().cast<DenseElementsAttr>();
    auto type =
        typeConverter->convertType(const_op.getType()).cast<ShapedType>();
    if (type != const_op.getType()) {
      // Signedness conversion.
      value_attr = value_attr.mapValues(type.getElementType(),
                                        [](const APInt& i) { return i; });
    }
    rewriter.replaceOpWithNewOp<ConstantOp>(const_op, type, value_attr);
    return success();
  }
};

class ReduceConverter : public OpConversionPattern<lmhlo::ReduceOp> {
 public:
  using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::ReduceOp reduce_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = reduce_op.getLoc();
    lmhlo::ReduceOp::Adaptor adaptor(args);
    auto operand_shape =
        adaptor.inputs()[0].getType().template dyn_cast<ShapedType>();
    if (!operand_shape || !operand_shape.hasRank()) {
      return rewriter.notifyMatchFailure(reduce_op, "expects known-rank args");
    }

    // First fill the output buffer with the init value.
    Value init_value =
        rewriter.create<memref::LoadOp>(loc, adaptor.init_values()[0]);
    rewriter.create<linalg::FillOp>(loc, init_value, adaptor.out()[0]);

    DenseIntElementsAttr dimensions_attr = reduce_op.dimensions();
    SmallVector<int, 4> reduction_dims;
    for (const auto& dim : dimensions_attr.getIntValues()) {
      reduction_dims.push_back(dim.getSExtValue());
    }

    SmallVector<AffineExpr, 2> src_exprs;
    SmallVector<AffineExpr, 2> dst_exprs;
    SmallVector<StringRef, 4> types;
    for (int i = 0, rank = operand_shape.getRank(); i != rank; ++i) {
      bool is_reduced = llvm::is_contained(reduction_dims, i);
      types.push_back(is_reduced ? getReductionIteratorTypeName()
                                 : getParallelIteratorTypeName());

      src_exprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (!is_reduced) {
        dst_exprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      }
    }

    auto maps = AffineMap::inferFromExprList({src_exprs, dst_exprs});

    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/ArrayRef<Type>{},
        /*inputs=*/adaptor.inputs(), /*outputBuffers=*/adaptor.out(), maps,
        types);
    rewriter.inlineRegionBefore(reduce_op.body(), linalg_op.region(),
                                linalg_op.region().end());
    {
      OpBuilder::InsertionGuard region_guard(rewriter);
      Block* block = linalg_op.getBody();
      rewriter.setInsertionPoint(&block->front());

      // The incoming region is operating on buffers, while linalg.generic
      // expects scalar SSA values. Add some allocs around the original op to
      // make it compatible.
      auto arg_type = block->getArgument(0).getType().cast<MemRefType>();
      Value alloc_a = rewriter.create<memref::AllocaOp>(loc, arg_type);
      Value alloc_b = rewriter.create<memref::AllocaOp>(loc, arg_type);
      Value alloc_res = rewriter.create<memref::AllocaOp>(loc, arg_type);

      // Now turn the existing signature
      //   (memref<X>, memref<X>, memref<X>) -> ()
      // into
      //   (X, X) -> X
      TypeConverter::SignatureConversion signature_converter(3);
      signature_converter.remapInput(0, alloc_a);
      signature_converter.remapInput(1, alloc_b);
      signature_converter.remapInput(2, alloc_res);
      signature_converter.addInputs(
          {arg_type.getElementType(), arg_type.getElementType()});
      Block* entry_block = rewriter.applySignatureConversion(
          &linalg_op.region(), signature_converter);

      // Store the arguments into the newly allocated buffers.
      rewriter.setInsertionPointAfter(alloc_res.getDefiningOp());
      rewriter.create<memref::StoreOp>(loc, entry_block->getArgument(0),
                                       alloc_a);
      rewriter.create<memref::StoreOp>(loc, entry_block->getArgument(1),
                                       alloc_b);
      rewriter.replaceOp(entry_block->getTerminator(), {});

      // Load & yield the result.
      rewriter.setInsertionPointToEnd(entry_block);
      auto load_res = rewriter.create<memref::LoadOp>(loc, alloc_res);
      rewriter.create<linalg::YieldOp>(loc, ValueRange{load_res});
    }

    rewriter.replaceOp(reduce_op, linalg_op.getOperation()->getResults());
    return success();
  }
};

// TODO(b/156787842): Support the lowering for dynamic shapes.
template <typename OpTy, bool isLHLO = true>
class ReverseConverter
    : public DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
                                     isLHLO> {
 public:
  using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
                                isLHLO>::DataMovementOpConverter;
  static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
    auto result_type =
        GetHloOpResultType<isLHLO>(op).template cast<ShapedType>();
    auto nloops = result_type.getRank();
    SmallVector<AffineExpr, 2> input_exprs;
    input_exprs.reserve(nloops);
    for (int i = 0; i < nloops; ++i)
      input_exprs.push_back(b->getAffineDimExpr(i));
    for (auto dim : op.dimensions()) {
      int i = dim.getZExtValue();
      if (result_type.isDynamicDim(i)) return {};
      int n = result_type.getShape()[i];
      input_exprs[i] = b->getAffineConstantExpr(n - 1) - input_exprs[i];
    }
    return {
        AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
        b->getMultiDimIdentityMap(nloops)};
  }
};

template <typename OpTy, bool isLHLO = true>
class SliceConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy slice_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = slice_op.getLoc();
    auto arg_type = args[0].getType().template dyn_cast<ShapedType>();
    if (!arg_type || !arg_type.hasRank()) {
      return rewriter.notifyMatchFailure(slice_op, "expects known-rank args");
    }

    SmallVector<OpFoldResult, 3> offsets, sizes, strides;
    for (int i = 0, e = arg_type.getRank(); i < e; ++i) {
      auto start = slice_op.start_indices().template getValue<int64_t>(i);
      auto limit = slice_op.limit_indices().template getValue<int64_t>(i);
      auto stride = slice_op.strides().template getValue<int64_t>(i);
      offsets.push_back(rewriter.getI64IntegerAttr(start));
      // Say that there are k elements in total, we have condition:
      //   start + (k - 1) * strides <= limit - 1
      // ->
      //   k <= (limit - 1 - start) / strides + 1
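      // E.g., start = 1, limit = 10, stride = 3 gives
      // k = (10 - 1 - 1) / 3 + 1 = 3 elements, at indices 1, 4, and 7.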
      sizes.push_back(
          rewriter.getI64IntegerAttr((limit - 1 - start) / stride + 1));
      strides.push_back(rewriter.getI64IntegerAttr(stride));
    }
    if (isLHLO) {
      auto linalg_op = rewriter.create<memref::SubViewOp>(loc, args[0], offsets,
                                                          sizes, strides);
      rewriter.create<linalg::CopyOp>(loc, linalg_op, args[1]);
      rewriter.eraseOp(slice_op);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
          slice_op, args[0], offsets, sizes, strides);
    }
    return success();
  }
};

class DynamicSliceConverter : public OpConversionPattern<mhlo::DynamicSliceOp> {
 public:
  using OpConversionPattern<mhlo::DynamicSliceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::DynamicSliceOp dynamic_slice_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = dynamic_slice_op.getLoc();
    mhlo::DynamicSliceOp::Adaptor adaptor(args);
    auto arg_type = adaptor.operand().getType().dyn_cast<ShapedType>();
    if (!arg_type || !arg_type.hasRank()) {
      return rewriter.notifyMatchFailure(dynamic_slice_op,
                                         "require known-rank args");
    }

    auto index_type = rewriter.getIndexType();
    SmallVector<OpFoldResult, 3> start_indices, sizes;
    Value zero = rewriter.create<ConstantOp>(
        loc, rewriter.getZeroAttr(adaptor.start_indices()[0]
                                      .getType()
                                      .cast<RankedTensorType>()
                                      .getElementType()));
    for (auto en : llvm::enumerate(
             llvm::zip(adaptor.start_indices(),
                       dynamic_slice_op.slice_sizes().getValues<int64_t>()))) {
      int64_t size = std::get<1>(en.value());
      sizes.push_back(rewriter.getI64IntegerAttr(size));

      // By mhlo.DynamicSlice definition:
      //   `start_indices[i] = clamp(start_indices[i],
      //       0, operand.dimension_size[i] - size_indices[i])`
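      // E.g., for an operand dimension of size 10 and a slice size of 4, a
      // start index of 8 is clamped down to 10 - 4 = 6 so the slice stays in
      // bounds.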
      Value start_index =
          rewriter.create<tensor::ExtractOp>(loc, std::get<0>(en.value()));
      Value ub = rewriter.createOrFold<tensor::DimOp>(loc, adaptor.operand(),
                                                      en.index());
      // ClampOp lowering does not support index type, so cast it into integer
      // type.
      ub = rewriter.createOrFold<IndexCastOp>(loc, start_index.getType(), ub);
      ub = rewriter.createOrFold<SubIOp>(
          loc, ub,
          rewriter.create<ConstantOp>(
              loc, rewriter.getIntegerAttr(start_index.getType(), size)));
      // TODO(hanchung): This is a workaround to use the method because only
      // lmhlo version is defined. The implementation in
      // map_lmhlo_to_scalar_op.h requires to pass a mhlo op. It will convert it
      // to an lmhlo op and call the lmhlo implementation.
      start_index = lmhlo::HloOpToStdScalarOp::map<lmhlo::ClampOp>(
          loc, start_index.getType(),
          ArrayRef<Type>{start_index.getType(), start_index.getType(),
                         start_index.getType()},
          ArrayRef<Value>{zero, start_index, ub}, &rewriter);
      start_indices.push_back(
          rewriter.create<IndexCastOp>(loc, index_type, start_index)
              .getResult());
    }

    int64_t rank = arg_type.getRank();
    SmallVector<OpFoldResult, 3> strides(rank, rewriter.getI64IntegerAttr(1));

    auto result_type =
        this->typeConverter->convertType(dynamic_slice_op.getType())
            .cast<RankedTensorType>();

    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        dynamic_slice_op, result_type, adaptor.operand(), start_indices, sizes,
        strides);
    return success();
  }
};

class DynamicUpdateSliceConverter
    : public OpConversionPattern<mhlo::DynamicUpdateSliceOp> {
 public:
  using OpConversionPattern<mhlo::DynamicUpdateSliceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::DynamicUpdateSliceOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = op.getLoc();
    mhlo::DynamicUpdateSliceOp::Adaptor adaptor(args);
    auto operand_type =
        adaptor.operand().getType().dyn_cast<RankedTensorType>();
    if (!operand_type || !operand_type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          op, "require static ranked type for operand");
    }

    auto update_type = adaptor.update().getType().dyn_cast<RankedTensorType>();
    if (!update_type || !update_type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          op, "require static ranked type for update");
    }

    // We do not have to clamp sizes because the semantic of `update`
    // guarantees that it is always in the bounds. See
    // https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice
    SmallVector<OpFoldResult, 3> sizes;
    for (auto size : update_type.getShape()) {
      sizes.push_back(rewriter.getIndexAttr(size));
    }

    auto index_type = rewriter.getIndexType();
    SmallVector<OpFoldResult, 3> start_indices;
    Type start_index_type = adaptor.start_indices()[0]
                                .getType()
                                .cast<RankedTensorType>()
                                .getElementType();
    Value zero = rewriter.create<ConstantOp>(
        loc, rewriter.getZeroAttr(start_index_type));
    for (auto en : llvm::enumerate(adaptor.start_indices())) {
      // By mhlo.DynamicUpdateSlice definition:
      //   `start_indices[i] = clamp(start_indices[i],
      //       0, operand.dimension_size[i] - update.dimension_size[i])`
      Value start_index = rewriter.create<tensor::ExtractOp>(loc, en.value());
      Value ub = rewriter.create<ConstantOp>(
          loc, rewriter.getIntegerAttr(start_index_type,
                                       operand_type.getDimSize(en.index()) -
                                           update_type.getDimSize(en.index())));
      // TODO(hanchung): This is a workaround to use the method because only
      // lmhlo version is defined. The implementation in
      // map_lmhlo_to_scalar_op.h requires to pass a mhlo op. It will convert it
      // to an lmhlo op and call the lmhlo implementation.
      start_index = lmhlo::HloOpToStdScalarOp::map<lmhlo::ClampOp>(
          loc, start_index_type,
          ArrayRef<Type>{start_index_type, start_index_type, start_index_type},
          ArrayRef<Value>{zero, start_index, ub}, &rewriter);
      start_indices.push_back(
          rewriter.create<IndexCastOp>(loc, index_type, start_index)
              .getResult());
    }

    int64_t rank = operand_type.getRank();
    SmallVector<OpFoldResult, 3> strides(rank, rewriter.getI64IntegerAttr(1));
    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        op, adaptor.update(), adaptor.operand(), start_indices, sizes, strides);
    return success();
  }
};

enum class DotOperationType {
  kVectorDot = 0,
  kMatrixVector = 1,
  kMatrixMatrix = 2,
  kUnsupported = 3
};

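// Classifies an mhlo.dot by the ranks of its operands: 1-D x 1-D is a vector
// dot, 2-D x 1-D a matrix-vector product, and 2-D x 2-D a matrix-matrix
// product. Dynamic dimensions are treated as compatible with any size.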
DotOperationType GetDotOperationType(mhlo::DotOp dot_op) {
  ArrayRef<int64_t> lhs_shape =
      dot_op.lhs().getType().cast<ShapedType>().getShape();
  ArrayRef<int64_t> rhs_shape =
      dot_op.rhs().getType().cast<ShapedType>().getShape();
  auto shape_matches = [](int64_t a, int64_t b) {
    return a == ShapedType::kDynamicSize || b == ShapedType::kDynamicSize ||
           a == b;
  };
  if (lhs_shape.size() == 1 && rhs_shape.size() == 1 &&
      shape_matches(lhs_shape[0], rhs_shape[0])) {
    return DotOperationType::kVectorDot;
  }
  if (lhs_shape.size() == 2 && rhs_shape.size() == 1 &&
      shape_matches(lhs_shape[1], rhs_shape[0])) {
    return DotOperationType::kMatrixVector;
  }
  if (lhs_shape.size() == 2 && rhs_shape.size() == 2 &&
      shape_matches(lhs_shape[1], rhs_shape[0])) {
    return DotOperationType::kMatrixMatrix;
  }
  return DotOperationType::kUnsupported;
}

SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
                                                 Value lhs, Value rhs,
                                                 DotOperationType type) {
  SmallVector<Value, 2> dyn_shape;
  switch (type) {
    case DotOperationType::kMatrixMatrix: {
      if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
        dyn_shape.push_back(b.create<tensor::DimOp>(loc, lhs, 0));
      if (rhs.getType().cast<ShapedType>().isDynamicDim(1))
        dyn_shape.push_back(b.create<tensor::DimOp>(loc, rhs, 1));
      break;
    }
    case DotOperationType::kMatrixVector: {
      if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
        dyn_shape.push_back(b.create<tensor::DimOp>(loc, lhs, 0));
      break;
    }
    case DotOperationType::kVectorDot:
    case DotOperationType::kUnsupported:
    default: {
      break;
    }
  }
  return dyn_shape;
}

template <DotOperationType op_type, typename LinalgOp>
class DotOpOnTensorsConversion : public OpConversionPattern<mhlo::DotOp> {
 public:
  using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::DotOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
      return failure();
    }
    if (GetDotOperationType(op) != op_type) return failure();

    mhlo::DotOp::Adaptor adaptor(args);

    Location loc = op.getLoc();
    auto output_type = op.getType().cast<ShapedType>();
    auto output_el_type = output_type.getElementType();
    auto zero_attr = rewriter.getZeroAttr(output_el_type);
    Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
    SmallVector<Value, 2> dyn_shape = GetDotOpInitTensorDynSizes(
        rewriter, loc, adaptor.lhs(), adaptor.rhs(), op_type);
    auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
    Value zero_tensor =
        rewriter.create<linalg::FillOp>(loc, zero, init_tensor).getResult(0);
    rewriter.replaceOpWithNewOp<LinalgOp>(
        op, TypeRange{op.getType()}, ValueRange{adaptor.lhs(), adaptor.rhs()},
        ValueRange{zero_tensor});
    return success();
  }
};

SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
    OpBuilder& b, Location loc, Value lhs, Value rhs, ShapedType result_type) {
  SmallVector<Value, 8> dyn_shape;
  if (result_type.isDynamicDim(0))
    dyn_shape.push_back(b.create<tensor::DimOp>(loc, lhs, 0));
  if (result_type.isDynamicDim(1))
    dyn_shape.push_back(b.create<tensor::DimOp>(loc, lhs, 1));
  if (result_type.isDynamicDim(2))
    dyn_shape.push_back(b.create<tensor::DimOp>(loc, rhs, 2));
  return dyn_shape;
}

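// Lowers mhlo.dot_general to linalg.batch_matmul. Only the canonical
// batched-matmul form is matched: batching dimensions {0} on both sides and
// contracting dimensions lhs {2} / rhs {1}, i.e.
// [b, m, k] x [b, k, n] -> [b, m, n].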
class DotGeneralOpOnTensorsConversion
    : public OpConversionPattern<mhlo::DotGeneralOp> {
 public:
  using OpConversionPattern<mhlo::DotGeneralOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::DotGeneralOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    if (!VerifyHloOpBufferOrTensorSemantics</*isLHLO=*/false>(op)) {
      return failure();
    }

    mhlo::DotDimensionNumbers dim_numbers = op.dot_dimension_numbers();
    auto lhs_batching_dims =
        Extract1DVector(dim_numbers.lhs_batching_dimensions());
    auto rhs_batching_dims =
        Extract1DVector(dim_numbers.rhs_batching_dimensions());
    auto lhs_contracting_dims =
        Extract1DVector(dim_numbers.lhs_contracting_dimensions());
    auto rhs_contracting_dims =
        Extract1DVector(dim_numbers.rhs_contracting_dimensions());
    if (lhs_batching_dims.size() != 1 || lhs_batching_dims[0] != 0) {
      return rewriter.notifyMatchFailure(
          op, "expected lhs batching dimensions exactly {0}");
    }
    if (rhs_batching_dims.size() != 1 || rhs_batching_dims[0] != 0) {
      return rewriter.notifyMatchFailure(
          op, "expected rhs batching dimensions exactly {0}");
    }
    if (lhs_contracting_dims.size() != 1 || lhs_contracting_dims[0] != 2) {
      return rewriter.notifyMatchFailure(
          op, "expected lhs contracting dimensions exactly {2}");
    }
    if (rhs_contracting_dims.size() != 1 || rhs_contracting_dims[0] != 1) {
      return rewriter.notifyMatchFailure(
          op, "expected rhs contracting dimensions exactly {1}");
    }

    mhlo::DotGeneralOp::Adaptor adaptor(args);

    Location loc = op.getLoc();
    auto output_type = op.getType().cast<ShapedType>();
    auto output_el_type = output_type.getElementType();
    SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
        rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type);
    auto zero_attr = rewriter.getZeroAttr(output_el_type);
    Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
    auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
    Value zero_tensor =
        rewriter.create<linalg::FillOp>(loc, zero, init_tensor).getResult(0);
    Operation* linalg_op = rewriter.create<linalg::BatchMatmulOp>(
        loc, /*resultTensorTypes=*/TypeRange{op.getType()},
        /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
        /*outputBuffers=*/ValueRange{zero_tensor});

    rewriter.replaceOp(op, linalg_op->getResults());
    return success();
  }
};

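// Returns true if the given op is nested directly inside the body of an op
// from the Linalg dialect. The reduce-region patterns below only apply in
// that position.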
bool IsInBodyOfLinalgOps(Operation* op) {
  auto parent_op = op->getParentRegion()->getParentOp();
  return parent_op->getDialect() ==
         parent_op->getContext()->getLoadedDialect<linalg::LinalgDialect>();
}

template <typename OpTy>
struct ReduceRegionXLAOpConversion : public OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      OpTy op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    if (!IsInBodyOfLinalgOps(op)) {
      return failure();
    }
    if (!op.getResult().getType().template isa<TensorType>()) return failure();
    if (llvm::all_of(args, [](Value arg) {
          return arg.getType().template isa<TensorType>();
        })) {
      return failure();
    }
    Value result = lmhlo::HloOpToStdScalarOp::map<OpTy>(
        op, getElementTypeOrSelf(op.getType()), args, &rewriter);
    rewriter.replaceOp(op, result);
    return success();
  }
};

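// Returns the dynamic sizes needed to create the init tensor of a reduce:
// one tensor::DimOp per dynamic, non-reduced dimension of the operand.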
SmallVector<Value, 8> GetReduceOpInitTensorDynSizes(
    OpBuilder& b, Location loc, Value arg, ShapedType result_type,
    ArrayRef<int64_t> reduction_dims) {
  llvm::SmallSetVector<int, 4> s;
  for (auto dim : reduction_dims) s.insert(dim);

  SmallVector<unsigned, 4> parallel_dims;
  SmallVector<Value, 8> dyn_shape;
  int rank = arg.getType().cast<RankedTensorType>().getRank();
  for (int i = 0, j = 0; i < rank; ++i) {
    if (s.count(i)) continue;
    if (!result_type.isDynamicDim(j++)) continue;
    dyn_shape.push_back(b.create<tensor::DimOp>(loc, arg, i));
  }

  return dyn_shape;
}

class ReduceRegionReturnOpConversion
    : public OpConversionPattern<mhlo::ReturnOp> {
 public:
  using OpConversionPattern<mhlo::ReturnOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::ReturnOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    if (!IsInBodyOfLinalgOps(op)) {
      return failure();
    }
    rewriter.replaceOpWithNewOp<linalg::YieldOp>(op, args);
    return success();
  }
};

class ReduceOnTensorsConversion : public OpConversionPattern<mhlo::ReduceOp> {
 public:
  using OpConversionPattern<mhlo::ReduceOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::ReduceOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    Location loc = op.getLoc();
    mhlo::ReduceOp::Adaptor adaptor(args);

    int num_inputs = static_cast<int>(adaptor.inputs().size());
    auto src_type = adaptor.inputs()[0].getType().cast<ShapedType>();
    int src_rank = src_type.getRank();
    if (!src_rank) {
      return rewriter.notifyMatchFailure(op, "expects known-rank args");
    }

    SmallVector<int64_t, 4> reduction_dims = Extract1DVector(op.dimensions());

    SmallVector<Value> inputs, outputs;
    SmallVector<AffineMap, 3> indexing_maps;
    for (int i = 0; i < num_inputs; ++i) {
      Value src = adaptor.inputs()[i];
      if (src.getType() != src_type) return failure();

      // Check if init_value is constant. If so, inline the value into the
      // region.
      Value init_value = adaptor.init_values()[i];
      Attribute init_const_val = GetInitValueAsConst(init_value);
      if (init_const_val) {
        init_value = rewriter.create<ConstantOp>(
            init_value.getDefiningOp()->getLoc(), init_const_val);
      } else {
        init_value = rewriter.create<tensor::ExtractOp>(loc, init_value);
      }

      inputs.push_back(src);
      auto result_type = op.getResult(i).getType().cast<ShapedType>();
      SmallVector<Value, 8> dyn_shape = GetReduceOpInitTensorDynSizes(
          rewriter, loc, src, result_type, reduction_dims);
      auto init_tensor = GetInitTensor(rewriter, loc, result_type, dyn_shape);
      Value filled_tensor =
          rewriter.create<linalg::FillOp>(loc, init_value, init_tensor)
              .result();
      outputs.push_back(filled_tensor);
    }

    // Prepare indexing maps for linalg generic op. The elements are for src
    // and dst. Transpose `src` to make the reduction loops the innermost,
    // because it's easier to fully utilize processors.
    indexing_maps.append(
        num_inputs, GetTransposeMapForReduction(rewriter.getContext(), src_rank,
                                                reduction_dims));

    // The indexing map of `dst` should drop the reduction loops. Since the
    // reduction loops are now the innermost, drop the last
    // `reduction_dims.size()` dimensions. We don't need an inverse
    // permutation here because they are the same.
    SmallVector<AffineExpr, 4> exprs;
    for (int i = 0, e = src_rank - reduction_dims.size(); i < e; ++i)
      exprs.push_back(rewriter.getAffineDimExpr(i));
    indexing_maps.append(num_inputs,
                         AffineMap::get(src_rank, /*symbolCount=*/0, exprs,
                                        rewriter.getContext()));

    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/op.getResultTypes(), inputs,
        /*outputBuffers=*/ValueRange{outputs}, indexing_maps,
        GetParallelAndReductionIterators(src_rank, reduction_dims.size()));

    // Convert the signature of the body. The reduce op region apply function
    // has a signature (lhs, rhs) -> output, all of the same tensor type t.
    // This is converted to a function with the same signature but with
    // element types. E.g., "(tensor<f32>, tensor<f32>) -> tensor<f32>" will
    // be converted to "(f32, f32, f32)".
    Region& region = linalg_op.region();
    rewriter.inlineRegionBefore(op.body(), region, region.end());
    TypeConverter::SignatureConversion signature_converter(num_inputs * 2);
    for (int i = 0; i < num_inputs * 2; ++i)
      signature_converter.addInputs(i, src_type.getElementType());
    rewriter.applySignatureConversion(&region, signature_converter);
    rewriter.replaceOp(op, linalg_op.getResults());
    return success();
  }
};

/// Converts mhlo.pad operation to linalg.pad_tensor op.
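/// Interior padding must be zero; edge_padding_low and edge_padding_high map
/// directly to the low/high padding of the created pad_tensor op.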
struct PadOpOnTensorsConversion : public OpConversionPattern<mhlo::PadOp> {
  using OpConversionPattern<mhlo::PadOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::PadOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const override {
    mhlo::PadOp::Adaptor adaptor(args);
    if (llvm::any_of(
            op.interior_padding().getValues<APInt>(),
            [](const APInt& int_val) { return int_val.getZExtValue() != 0; })) {
      return rewriter.notifyMatchFailure(op, "expected no interior padding");
    }

    auto loc = op.getLoc();
    Value padding_val =
        rewriter.createOrFold<tensor::ExtractOp>(loc, adaptor.padding_value());

    const auto& edge_padding_low = op.edge_padding_low();
    const auto& edge_padding_high = op.edge_padding_high();
    SmallVector<OpFoldResult, 4> low, high;
    for (auto it : llvm::zip(edge_padding_low, edge_padding_high)) {
      low.push_back(rewriter.createOrFold<ConstantIndexOp>(
          loc, std::get<0>(it).getZExtValue()));
      high.push_back(rewriter.createOrFold<ConstantIndexOp>(
          loc, std::get<1>(it).getZExtValue()));
    }
    Type result_type = op.getResult().getType();
    auto pad_tensor_op = linalg::PadTensorOp::createPadScalarOp(
        result_type, adaptor.operand(), padding_val, low, high, loc, rewriter);
    rewriter.replaceOp(op, pad_tensor_op.getResult());
    return success();
  }
};

/// Converts mhlo.conv operation to linalg named op. This only covers normal
/// convolution cases. The op must have canonical dimension numbers. Depthwise
/// convolution and pointwise convolution are not handled in the conversion.
struct NormalConvOpOnTensorsConversion
    : public OpConversionPattern<mhlo::ConvOp> {
  using OpConversionPattern<mhlo::ConvOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ConvOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const override {
    if (!HasCanonicalDimensionNumbers(op.dimension_numbers())) return failure();
    if (op.feature_group_count() != 1u) return failure();

    mhlo::ConvOp::Adaptor adaptor(args);
    Location loc = op.getLoc();
    Value input = adaptor.lhs();
    Value filter = adaptor.rhs();
    auto result_type = op.getResult().getType().cast<ShapedType>();
    int64_t rank = result_type.getRank();

    // Check if padding is zero.
    DenseIntElementsAttr padding = op.paddingAttr();
    if (padding && !isSplatValue(*op.padding(), 0)) {
      return rewriter.notifyMatchFailure(op, "expected no padding");
    }

    // The output shape is N spatial_dims F.
    SmallVector<Value, 8> dyn_sizes;
    if (result_type.isDynamicDim(0)) {
      dyn_sizes.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
    }
    for (int64_t i = 1, e = rank - 1; i < e; ++i) {
      if (result_type.isDynamicDim(i)) {
        return rewriter.notifyMatchFailure(
            op, "expected output spatial dims to be static shapes");
      }
    }
    if (result_type.isDynamicDim(rank - 1)) {
      dyn_sizes.push_back(
          rewriter.create<tensor::DimOp>(loc, filter, rank - 1));
    }
    Value init_tensor = rewriter.create<linalg::InitTensorOp>(
        loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
    auto zero_attr = rewriter.getZeroAttr(result_type.getElementType());
    Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
    Value zero_tensor =
        rewriter.create<linalg::FillOp>(loc, zero, init_tensor).getResult(0);
    linalg::LinalgOp res;
    Attribute strides = op.window_stridesAttr();
    // TODO(ataei): Only support dilated kernel right now. We need to consider
    // input dilation for deconvolution cases.
    Attribute dilations = op.rhs_dilationAttr();
    switch (rank) {
      case 3: {
        res = rewriter.create<linalg::ConvInputNWCFilterWCFOp>(
            loc, result_type, ValueRange{input, filter},
            ValueRange{zero_tensor}, dilations, strides);
        break;
      }
      case 4: {
        res = rewriter.create<linalg::ConvInputNHWCFilterHWCFOp>(
            loc, result_type, ValueRange{input, filter},
            ValueRange{zero_tensor}, dilations, strides);
        break;
      }
      case 5: {
        res = rewriter.create<linalg::ConvInputNDHWCFilterDHWCFOp>(
            loc, result_type, ValueRange{input, filter},
            ValueRange{zero_tensor}, dilations, strides);
        break;
      }
      default:
        return rewriter.notifyMatchFailure(op, "expected 1/2/3D conv op");
    }
    rewriter.replaceOp(op, res.getOperation()->getResults());
    return success();
  }
};

/// Converts mhlo.convolution operation to
/// linalg.depthwise_conv_2d_input_nhwc_filter_hwcf op or
/// depthwise_conv_2d_input_nhwc_filter_hwc op.
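/// When the channel multiplier is not 1, the ..._hwcf variant is used and its
/// 5-D output is collapsed back to the 4-D result type; with a multiplier of
/// 1, the filter is collapsed to 3-D and the ..._hwc variant is used.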
struct DepthwiseConvOpOnTensorsConversion
    : public OpConversionPattern<mhlo::ConvOp> {
  using OpConversionPattern<mhlo::ConvOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ConvOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const override {
    if (op.batch_group_count() != 1) return failure();

    if (op.padding() && !isSplatValue(*op.padding(), 0)) {
      return rewriter.notifyMatchFailure(op,
                                         "non-zero padding unsupported yet");
    }

    if ((op.lhs_dilation() && !isSplatValue(*op.lhs_dilation(), 1))) {
      return rewriter.notifyMatchFailure(
          op, "non-one lhs dilation unsupported yet");
    }

    if (const mhlo::ConvDimensionNumbers& dimension_numbers =
            op.dimension_numbers()) {
      // Make sure that this is 2-D convolution.
      const auto spatial_rank =
          llvm::size(dimension_numbers.input_spatial_dimensions());
      if (spatial_rank != 2) {
        return rewriter.notifyMatchFailure(op,
                                           "only support 2-D cases for now");
      }

      // Make sure that this is depthwise convolution.
      int64_t input_feature_dim =
          dimension_numbers.input_feature_dimension().getInt();
      int64_t input_feature_count =
          op.lhs().getType().cast<ShapedType>().getDimSize(input_feature_dim);
      if (op.feature_group_count() != input_feature_count) {
        return rewriter.notifyMatchFailure(op, "not depth-wise convolution");
      }

      // Make sure that this convolution has a canonical form.
      if (!HasCanonicalDimensionNumbers(dimension_numbers)) {
        return rewriter.notifyMatchFailure(op, "does not have canonical form");
      }
    }

    DenseIntElementsAttr window_strides;
    if (op.window_strides()) {
      window_strides = op.window_strides().getValue();
    } else {
      window_strides = rewriter.getI64VectorAttr({1, 1});
    }

    DenseIntElementsAttr rhs_dilation;
    if (op.rhs_dilation()) {
      rhs_dilation = op.rhs_dilation().getValue();
    } else {
      rhs_dilation = rewriter.getI64VectorAttr({1, 1});
    }

    mhlo::ConvOp::Adaptor adaptor(args);
    Location loc = op.getLoc();
    Value input = adaptor.lhs();
    Value filter = adaptor.rhs();
    auto result_type = op.getResult().getType().cast<RankedTensorType>();
    if (!result_type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "expected output has static shapes");
    }

    auto filter_dims =
        llvm::to_vector<4>(op.rhs().getType().cast<ShapedType>().getShape());

    auto get_indices_vector = [](int start, int end) {
      return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
    };

    if (filter_dims[2] * filter_dims[3] != op.feature_group_count()) {
      // For cases where channel multiplier != 1
      auto output_dims = result_type.getShape();
      auto channel_multiplier = filter_dims[3];
      SmallVector<int64_t> reshaped_output_dims;
      reshaped_output_dims.assign(output_dims.begin(), output_dims.end());
      reshaped_output_dims.push_back(channel_multiplier);
      reshaped_output_dims[3] /= channel_multiplier;

      Value init_tensor = rewriter.create<linalg::InitTensorOp>(
          loc, reshaped_output_dims, result_type.getElementType());
      auto zero_attr = rewriter.getZeroAttr(result_type.getElementType());
      Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
      Value zero_tensor =
          rewriter.create<linalg::FillOp>(loc, zero, init_tensor).getResult(0);

      auto reshaped_output_type = RankedTensorType::get(
          reshaped_output_dims, result_type.getElementType());
      auto conv = rewriter.create<linalg::DepthwiseConvInputNHWCFilterHWCFOp>(
          op.getLoc(), reshaped_output_type, ValueRange{input, filter},
          ValueRange{zero_tensor}, rhs_dilation, window_strides);

      // Create a Linalg reshape op that converts the output from 5 dimensions
      // into 4 dimensions (by collapsing the last two dimensions). This is
      // needed because linalg.depthwise_conv_2d_input_nhwc_filter_hwcf returns
      // 5 dimensions for the output.
      SmallVector<ReassociationIndices, 4> collapsed_dim_list = {
          get_indices_vector(0, 1), get_indices_vector(1, 2),
          get_indices_vector(2, 3), get_indices_vector(3, 5)};
      rewriter.replaceOpWithNewOp<linalg::TensorCollapseShapeOp>(
          op, result_type, conv.getResult(0), collapsed_dim_list);
    } else {
      // For cases where channel multiplier == 1
      Value init_tensor = rewriter.create<linalg::InitTensorOp>(
          loc, result_type.getShape(), result_type.getElementType());
      auto zero_attr = rewriter.getZeroAttr(result_type.getElementType());
      Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
      Value zero_tensor =
          rewriter.create<linalg::FillOp>(loc, zero, init_tensor).getResult(0);

      // Create a Linalg reshape op that converts the filter from 4 dimensions
      // into 3 dimensions (by dropping the unit dimension). This is needed
      // because linalg.depthwise_conv_2d_input_nhwc_filter_hwc expects 3
      // dimensions for the filter.

      filter_dims[2] = static_cast<int64_t>(op.feature_group_count());
      filter_dims.pop_back();

      RankedTensorType filter_shape =
          RankedTensorType::get(filter_dims, op.getType().getElementType());

      SmallVector<ReassociationIndices, 4> collapsed_dim_list = {
          get_indices_vector(0, 1), get_indices_vector(1, 2),
          get_indices_vector(2, 4)};

      Value reshaped_filter = rewriter.create<linalg::TensorCollapseShapeOp>(
          loc, filter_shape, filter, collapsed_dim_list);

      rewriter.replaceOpWithNewOp<linalg::DepthwiseConvInputNHWCFilterHWCOp>(
          op, result_type, ValueRange{input, reshaped_filter},
          ValueRange{zero_tensor}, rhs_dilation, window_strides);
    }

    return success();
  }
};

struct ReduceWindowOpOnTensorsConversion
    : public OpConversionPattern<mhlo::ReduceWindowOp> {
  using OpConversionPattern<mhlo::ReduceWindowOp>::OpConversionPattern;

  /// mhlo.reduce_window is mapped to a linalg.pooling operation. The type of
  /// the pooling is determined based on the body of the reduce window
  /// operation. This class enumerates the different variants.
  enum class PoolingType {
    kInvalid,
    k2DMin,
    k3DMin,
    k2DMax,
    k3DMax,
    k2DAdd,
    k3DAdd,
  };

  static PoolingType getPoolingType(mhlo::ReduceWindowOp reduce_op,
                                    int result_index) {
    auto rank =
        reduce_op.getResultTypes()[result_index].cast<ShapedType>().getRank();
    if (Operation* op = reduce_op.getReductionOp(result_index)) {
      if (isa<mhlo::MinOp>(*op) && rank == 4) return PoolingType::k2DMin;
      if (isa<mhlo::MinOp>(*op) && rank == 5) return PoolingType::k3DMin;
      if (isa<mhlo::MaxOp>(*op) && rank == 4) return PoolingType::k2DMax;
      if (isa<mhlo::MaxOp>(*op) && rank == 5) return PoolingType::k3DMax;
      if (isa<mhlo::AddOp>(*op) && rank == 4) return PoolingType::k2DAdd;
      if (isa<mhlo::AddOp>(*op) && rank == 5) return PoolingType::k3DAdd;
    }
    return PoolingType::kInvalid;
  }

  LogicalResult matchAndRewrite(
      mhlo::ReduceWindowOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const override {
    auto loc = op.getLoc();
    int rank = op.getResultTypes()[0].cast<ShapedType>().getRank();
    if (rank != 4 && rank != 5) {
      return rewriter.notifyMatchFailure(
          op, "expected NHWC/NDHWC pooling-based op");
    }

    if (op.padding() && !isSplatValue(*op.padding(), 0)) {
      return rewriter.notifyMatchFailure(op, "require paddings are all zero");
    }

    int last_dim = rank - 1;
    SmallVector<int64_t, 2> fake_window_shapes;
    for (int i = 1; i < last_dim; ++i) {
      fake_window_shapes.push_back(op.window_dimensions().getValue<int64_t>(i));
    }

    if (op.window_strides() &&
        (op.window_strides().getValue().getValue<int64_t>(0) != 1 ||
         op.window_strides().getValue().getValue<int64_t>(last_dim) != 1)) {
      return rewriter.notifyMatchFailure(
          op, "expected window_strides to be [1,x,y,(z),1]");
    }
    if (op.window_dimensions() &&
        (op.window_dimensions().getValue<int64_t>(0) != 1 ||
         op.window_dimensions().getValue<int64_t>(last_dim) != 1)) {
      return rewriter.notifyMatchFailure(
          op, "expected window_dimensions to be [1,x,y,(z),1]");
    }

    Attribute strides;
    SmallVector<int64_t> vec;
    if (op.window_stridesAttr()) {
      for (int i = 1; i < last_dim; ++i) {
        vec.push_back(op.window_strides().getValue().getValue<int64_t>(i));
      }
    } else {
      vec.assign(rank - 2, 1);
    }
    strides = rewriter.getI64VectorAttr(vec);

    Attribute dilations;
    vec.clear();
    if (op.window_dilations()) {
      for (int i = 1; i < last_dim; ++i) {
        vec.push_back(op.window_dilations().getValue().getValue<int64_t>(i));
      }
    } else {
      vec.assign(rank - 2, 1);
    }
    dilations = rewriter.getI64VectorAttr(vec);

    SmallVector<Value> pooling_ops;

    ArrayRef<Value> inputs = args.take_front(op.inputs().size());
    ArrayRef<Value> init_values = args.drop_front(op.inputs().size());
    for (auto it : llvm::zip(op.getResults(), inputs, init_values)) {
      OpResult result = std::get<0>(it);
      Value input = std::get<1>(it);
      Value init_value = std::get<2>(it);
      auto result_type = result.getType().cast<ShapedType>();
      if (!input.getType().cast<ShapedType>().getElementType().isF32()) {
        return rewriter.notifyMatchFailure(op,
                                           "expected element type to be f32");
      }

      // Create a fake window dimension.
      auto fake_window_dims = rewriter.create<linalg::InitTensorOp>(
          loc, fake_window_shapes, result_type.getElementType());

      SmallVector<Value> result_dynamic_dims;
      for (auto& en : llvm::enumerate(result_type.getShape())) {
        if (en.value() != ShapedType::kDynamicSize) continue;
        Value dim_size = rewriter.create<tensor::DimOp>(loc, input, en.index());
        if (en.index() == 0 || en.index() == rank - 1) {
          // batch dims and channel dims can be derived from input dims
          // directly.
          result_dynamic_dims.push_back(dim_size);
        } else {
          auto i = en.index() - 1;
          auto stride =
              strides.cast<DenseIntElementsAttr>().getValue<int64_t>(i);
          auto dilation =
              dilations.cast<DenseIntElementsAttr>().getValue<int64_t>(i);
          // let j = i * stride
          // output[i] = reduce( input[j, j + window_size * dilation) )
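          // E.g., an input dim of size 10 with window 3, dilation 1 and
          // stride 2 yields (10 - 3 * 1) / 2 + 1 = 4 output elements.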
          Value offset = rewriter.create<ConstantIndexOp>(
              loc, fake_window_shapes[i] * dilation);
          dim_size = rewriter.create<SubIOp>(loc, dim_size, offset);
          dim_size = rewriter.create<UnsignedDivIOp>(
              loc, dim_size, rewriter.create<ConstantIndexOp>(loc, stride));
          dim_size = rewriter.create<AddIOp>(
              loc, dim_size, rewriter.create<ConstantIndexOp>(loc, 1));
          result_dynamic_dims.push_back(dim_size);
        }
      }
      Value init_tensor = rewriter.create<linalg::InitTensorOp>(
          loc, result_dynamic_dims, result_type.getShape(),
          result_type.getElementType());

      init_value = rewriter.create<tensor::ExtractOp>(loc, init_value);
      Value filled_init_tensor =
          rewriter.create<linalg::FillOp>(loc, init_value, init_tensor)
              .getResult(0);
      auto create_op = [&](auto* type_ptr) -> linalg::LinalgOp {
        return cast<linalg::LinalgOp>(
            rewriter
                .create<std::remove_pointer_t<decltype(type_ptr)>>(
                    loc, ArrayRef<Type>{result_type},
                    ValueRange{input, fake_window_dims.getResult()},
                    filled_init_tensor, strides, dilations)
                .getOperation());
      };
      linalg::LinalgOp pooling_op;
      PoolingType pooling_type = getPoolingType(op, result.getResultNumber());
      switch (pooling_type) {
        case PoolingType::k2DMin: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNhwcMinOp*>(nullptr));
          break;
        }
        case PoolingType::k3DMin: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNdhwcMinOp*>(nullptr));
          break;
        }
        case PoolingType::k2DMax: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNhwcMaxOp*>(nullptr));
          break;
        }
        case PoolingType::k3DMax: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNdhwcMaxOp*>(nullptr));
          break;
        }
        case PoolingType::k2DAdd: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNhwcSumOp*>(nullptr));
          break;
        }
        case PoolingType::k3DAdd: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNdhwcSumOp*>(nullptr));
          break;
        }
        case PoolingType::kInvalid:
          return rewriter.notifyMatchFailure(op, "unknown reduction operation");
      }
      pooling_ops.push_back(pooling_op->getResult(0));
    }
    rewriter.replaceOp(op, pooling_ops);
    return success();
  }
};

/// Converts xla-hlo.torch_index_select op to a linalg.generic op.
struct TorchIndexSelectOpOnTensorsConversion
    : public OpConversionPattern<mhlo::TorchIndexSelectOp> {
  using OpConversionPattern<mhlo::TorchIndexSelectOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::TorchIndexSelectOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    mhlo::TorchIndexSelectOp::Adaptor adaptor(args);
    int axis = static_cast<int>(op.dim());
    int batch = static_cast<int>(op.batch_dims());
    auto index_shaped_type = adaptor.index().getType().cast<ShapedType>();
    int num_indices = static_cast<int>(index_shaped_type.getRank());
    auto input_shaped_type = adaptor.input().getType().cast<ShapedType>();
    if (axis < 0) axis += static_cast<int>(input_shaped_type.getRank());
    if (batch < 0) batch += num_indices;

    Location loc = op.getLoc();
    auto result_type =
        this->typeConverter->convertType(op.getResult().getType())
            .cast<ShapedType>();
    int rank = static_cast<int>(result_type.getRank());

    SmallVector<AffineMap, 2> indexing_maps;
    SmallVector<AffineExpr, 4> exprs;
    for (int i = 0; i < batch; ++i) {
      exprs.push_back(rewriter.getAffineDimExpr(i));
    }
    for (int i = 0, e = num_indices - batch; i < e; ++i) {
      exprs.push_back(rewriter.getAffineDimExpr(axis + i));
    }
    indexing_maps.emplace_back(
        AffineMap::get(rank, /*symbolCount=*/0, exprs, rewriter.getContext()));
    indexing_maps.emplace_back(rewriter.getMultiDimIdentityMap(rank));

    // The output shape is
    //   `params[:axis] + indices[batch_dims:] + params[axis + 1:]`
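    // E.g., params of shape [4, 7, 16] indexed with indices of shape [2, 3]
    // at axis = 1 with batch_dims = 0 produce a [4, 2, 3, 16] result.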
    SmallVector<Value, 4> dyn_sizes;
    for (int i = 0; i < rank; ++i) {
      if (!result_type.isDynamicDim(i)) continue;
      if (i < axis) {
        dyn_sizes.push_back(
            rewriter.create<tensor::DimOp>(loc, adaptor.input(), i));
      } else if (i < (axis + num_indices - batch)) {
        int idx = i - axis + batch;
        dyn_sizes.push_back(
            rewriter.create<tensor::DimOp>(loc, adaptor.index(), idx));
      } else {
        int idx = i - (axis + num_indices - batch) + axis + 1;
        dyn_sizes.push_back(
            rewriter.create<tensor::DimOp>(loc, adaptor.input(), idx));
      }
    }
    Value init_op = rewriter.create<linalg::InitTensorOp>(
        loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensors=*/ArrayRef<Type>{result_type},
        /*inputs=*/adaptor.index(),
        /*outputs=*/init_op, indexing_maps, GetNParallelLoopsAttrs(rank));

    SmallVector<Type, 4> body_arg_types;
    SmallVector<Value, 2> linalg_op_args = {adaptor.index()};
    // Add a block to the region.
    auto* region = &linalg_op.region();
    auto* block = rewriter.createBlock(region, region->end());
    for (auto block_args : linalg_op_args) {
      body_arg_types.push_back(
          block_args.getType().cast<ShapedType>().getElementType());
    }
    block->addArguments(body_arg_types);
    block->addArguments(result_type.getElementType());
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToEnd(block);

    Value casted_value = rewriter.create<IndexCastOp>(
        loc, block->getArgument(0), rewriter.getIndexType());

    SmallVector<Value, 4> indices;
    for (int i = 0; i < axis; ++i) {
      indices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
    }
    indices.push_back(casted_value);
    for (int i = axis + num_indices - batch; i < rank; ++i) {
      indices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
    }
    Value res =
        rewriter.create<tensor::ExtractOp>(loc, adaptor.input(), indices);
    rewriter.create<linalg::YieldOp>(loc, res);

    rewriter.replaceOp(op, linalg_op.getResults());
    return success();
  }
};

struct ScatterUpdateOnTensorsConversion
    : public OpConversionPattern<mhlo::ScatterOp> {
  using OpConversionPattern<mhlo::ScatterOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ScatterOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    mhlo::ScatterOp::Adaptor adaptor(args);

    // Check if it is a tensor_scatter_nd_update-like op.
    auto& body_ops = op.getRegion().front().getOperations();
    if (body_ops.size() != 1) return failure();
    auto ret_arg = body_ops.front().getOperand(0).dyn_cast<BlockArgument>();
    if (!ret_arg || ret_arg.getArgNumber() != 1) return failure();

    auto operand_ty = adaptor.operand().getType().dyn_cast<RankedTensorType>();
    auto indices_ty =
        adaptor.scatter_indices().getType().dyn_cast<RankedTensorType>();
    if (!operand_ty || !indices_ty) return failure();

    // Linalg operations put all the computation in the innermost loop. Since
    // we also iterate over scatter_indices() with some loops, we can only
    // check one scatter index in one iteration. If there are multiple indices
    // (i.e., the index depth is greater than 1), we don't have a way to keep
    // the comparison state. E.g., if the index_depth is 2, like
    // indices = [[0, 1]], we should use the update value only if
    // (i == 0 and j == 1). However, we can not get both indices in one
    // iteration unless we pack them together.
    auto index_vector_dim =
        op.scatter_dimension_numbers().index_vector_dim().getInt();
    if (indices_ty.getDimSize(index_vector_dim) != 1)
      return rewriter.notifyMatchFailure(op, "require index depth to be 1");
    if (index_vector_dim != indices_ty.getRank() - 1) {
      return rewriter.notifyMatchFailure(
          op, "require index_vector_dim to be the last dim");
    }

    // One of indices dims is index depth vector.
    int64_t nloops = operand_ty.getRank() + indices_ty.getRank() - 1;
    SmallVector<AffineMap, 3> indexing_maps;
    {
      SmallVector<AffineExpr> exprs;
      for (int64_t i = 0, e = operand_ty.getRank(); i < e; ++i)
        exprs.push_back(rewriter.getAffineDimExpr(i));
      indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
                                             rewriter.getContext()));
    }
    {
      SmallVector<AffineExpr> exprs;
      for (int64_t i = operand_ty.getRank(); i < nloops; ++i)
        exprs.push_back(rewriter.getAffineDimExpr(i));
      // The index depth is 1.
      exprs.push_back(rewriter.getAffineConstantExpr(0));
      indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
                                             rewriter.getContext()));

      exprs.pop_back();
      auto update_window_dims =
          Extract1DVector(op.scatter_dimension_numbers().update_window_dims());
      for (auto d : update_window_dims)
        exprs.push_back(rewriter.getAffineDimExpr(d));
      indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
                                             rewriter.getContext()));
    }
    indexing_maps.push_back(indexing_maps.front());

    auto result_ty = this->typeConverter->convertType(op.getResult().getType())
                         .cast<ShapedType>();
    auto scatter_dims_to_operand_dims = Extract1DVector(
        op.scatter_dimension_numbers().scatter_dims_to_operand_dims());
    assert(scatter_dims_to_operand_dims.size() == 1);
    // Do not need init_tensor because we'd like to initialize the output as
    // operand.
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        op.getLoc(), /*resultTensors=*/ArrayRef<Type>{result_ty},
        /*inputs=*/
        ValueRange{adaptor.operand(), adaptor.scatter_indices(),
                   adaptor.updates()},
        /*outputs=*/adaptor.operand(), indexing_maps,
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& b, Location loc, ValueRange args) {
          Value cmp_idx =
              b.create<linalg::IndexOp>(loc, scatter_dims_to_operand_dims[0]);
          Value idx = b.create<IndexCastOp>(loc, b.getIndexType(), args[1]);
          Value pred = b.create<CmpIOp>(loc, b.getI1Type(), CmpIPredicate::eq,
                                        cmp_idx, idx);
          // Use the output arg, so some update values won't be init value
          // again.
          Value res = b.create<SelectOp>(loc, args[2].getType(), pred, args[2],
                                         args[3]);
          b.create<linalg::YieldOp>(loc, res);
        });
    rewriter.replaceOp(op, linalg_op.getResults());
    return success();
  }
};

void populateLHLOToLinalgConversionPattern(MLIRContext* context,
                                           TypeConverter& typeConverter,
                                           OwningRewritePatternList* patterns) {
  // clang-format off
  patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
                   ConstConverterBuffer,
                   ConvToLinalgConverter,
                   IotaConverter<lmhlo::IotaOp>,
                   LhloBroadcastInDimConverter,
                   PointwiseToLinalgConverter<lmhlo::AbsOp>,
                   PointwiseToLinalgConverter<lmhlo::AddOp>,
                   PointwiseToLinalgConverter<lmhlo::AndOp>,
                   PointwiseToLinalgConverter<lmhlo::Atan2Op>,
                   PointwiseToLinalgConverter<lmhlo::CeilOp>,
                   PointwiseToLinalgConverter<lmhlo::ClampOp>,
                   PointwiseToLinalgConverter<lmhlo::CompareOp>,
                   PointwiseToLinalgConverter<lmhlo::ComplexOp>,
                   PointwiseToLinalgConverter<lmhlo::ConvertOp>,
                   // TODO(ataei): Remove this pattern, CopyOp is folded away.
                   PointwiseToLinalgConverter<lmhlo::CopyOp>,
                   PointwiseToLinalgConverter<lmhlo::CosOp>,
                   PointwiseToLinalgConverter<lmhlo::DivOp>,
                   PointwiseToLinalgConverter<lmhlo::ExpOp>,
                   PointwiseToLinalgConverter<lmhlo::Expm1Op>,
                   PointwiseToLinalgConverter<lmhlo::FloorOp>,
                   PointwiseToLinalgConverter<lmhlo::ImagOp>,
                   PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
                   PointwiseToLinalgConverter<lmhlo::LogOp>,
                   PointwiseToLinalgConverter<lmhlo::LogisticOp>,
                   PointwiseToLinalgConverter<lmhlo::Log1pOp>,
                   PointwiseToLinalgConverter<lmhlo::MaxOp>,
                   PointwiseToLinalgConverter<lmhlo::MinOp>,
                   PointwiseToLinalgConverter<lmhlo::MulOp>,
                   PointwiseToLinalgConverter<lmhlo::NegOp>,
                   PointwiseToLinalgConverter<lmhlo::NotOp>,
                   PointwiseToLinalgConverter<lmhlo::OrOp>,
                   PointwiseToLinalgConverter<lmhlo::PowOp>,
                   PointwiseToLinalgConverter<lmhlo::RealOp>,
                   PointwiseToLinalgConverter<lmhlo::RemOp>,
                   PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
                   PointwiseToLinalgConverter<lmhlo::SelectOp>,
                   PointwiseToLinalgConverter<lmhlo::ShiftLeftOp>,
                   PointwiseToLinalgConverter<lmhlo::ShiftRightArithmeticOp>,
                   PointwiseToLinalgConverter<lmhlo::ShiftRightLogicalOp>,
                   PointwiseToLinalgConverter<lmhlo::SignOp>,
                   PointwiseToLinalgConverter<lmhlo::SinOp>,
                   PointwiseToLinalgConverter<lmhlo::SqrtOp>,
                   PointwiseToLinalgConverter<lmhlo::SubOp>,
                   PointwiseToLinalgConverter<lmhlo::TanhOp>,
                   PointwiseToLinalgConverter<lmhlo::XorOp>,
                   ReduceConverter,
                   ReshapeOpConverter<lmhlo::ReshapeOp>,
                   ReverseConverter<lmhlo::ReverseOp>,
                   ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
                   ScalarPointwiseToStandardConverter<lmhlo::MaxOp>,
                   SliceConverter<lmhlo::SliceOp>,
                   TransposeConverter<lmhlo::TransposeOp>
                  >(typeConverter, context);
  // clang-format on
}
2723
struct ComputeReshapeShapeConversion
    : public OpConversionPattern<mhlo::ComputeReshapeShapeOp> {
  using OpConversionPattern<mhlo::ComputeReshapeShapeOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::ComputeReshapeShapeOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = op.getLoc();
    auto ctx = op->getContext();
    Value neg_one = rewriter.create<ConstantIndexOp>(loc, -1);
    auto index_type = rewriter.getIndexType();
    auto num_elements = args[0];
    auto target_shape_type = args[1].getType().cast<ShapedType>();
    auto extent_type =
        shape::getExtentTensorType(ctx, target_shape_type.getDimSize(0));

    // Compute the actual extent for a possibly-dynamic extent.
    auto new_shape =
        target_shape_type.getElementType().isIndex()
            ? args[1]
            : rewriter.create<IndexCastOp>(loc, extent_type, args[1]);
    Value new_shape_rank =
        rewriter.create<shape::RankOp>(loc, index_type, new_shape);
    // The product begins with a -1 seed, which cancels out a -1 extent in the
    // input shape if there is one. If there is none, the computed result is
    // never used, so computing a negative number of elements is harmless.
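    // For example (hypothetical values, for illustration only): with
    // num_elements = 24 and a target shape of [2, -1, 4], the seeded product
    // is -1 * 2 * -1 * 4 = 8, the missing extent below becomes 24 / 8 = 3,
    // and the final shape is [2, 3, 4].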
    auto accounted_num_els =
        rewriter.create<shape::ReduceOp>(loc, new_shape, neg_one);
    {
      PatternRewriter::InsertionGuard g(rewriter);
      rewriter.setInsertionPointToEnd(accounted_num_els.getBody());
      Value lhs = accounted_num_els.getBody()->getArgument(1);
      Value rhs = accounted_num_els.getBody()->getArgument(2);
      rewriter.create<shape::YieldOp>(
          loc, rewriter.create<MulIOp>(loc, lhs, rhs).result());
    }
    Value missing_dim_val = rewriter.create<UnsignedDivIOp>(
        loc, num_elements, accounted_num_els->getResult(0));

    // Create the final target shape, with any dynamic extent replaced by the
    // computed extent.
    SmallVector<Value> dynamic_extent;
    if (!target_shape_type.hasStaticShape())
      dynamic_extent.push_back(new_shape_rank);
    auto gen = rewriter.create<tensor::GenerateOp>(
        loc, target_shape_type, dynamic_extent,
        [&](OpBuilder& b, Location loc, ValueRange indices) {
          Value extent = b.create<shape::GetExtentOp>(loc, index_type,
                                                      new_shape, indices[0]);
          Value use_missing_dim_val =
              b.create<CmpIOp>(loc, CmpIPredicate::eq, extent, neg_one);
          Value dim_val = b.create<SelectOp>(loc, use_missing_dim_val,
                                             missing_dim_val, extent);
          dim_val = target_shape_type.getElementType().isIndex()
                        ? dim_val
                        : b.create<IndexCastOp>(
                              loc, target_shape_type.getElementType(), dim_val);
          b.create<tensor::YieldOp>(loc, dim_val);
        });
    rewriter.replaceOp(op, gen.result());

    return success();
  }
};

struct CstrReshapableConversion
    : public OpConversionPattern<mhlo::CstrReshapableOp> {
  using OpConversionPattern<mhlo::CstrReshapableOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::CstrReshapableOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = op.getLoc();
    auto ctx = op->getContext();
    Value neg_one = rewriter.create<ConstantIndexOp>(loc, -1);
    Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<ConstantIndexOp>(loc, 1);
    Value two = rewriter.create<ConstantIndexOp>(loc, 2);
    auto num_elements = args[0];
    auto target_shape_type = args[1].getType().cast<ShapedType>();
    auto extent_type =
        shape::getExtentTensorType(ctx, target_shape_type.getDimSize(0));

    // Compute the actual extent for a possibly-dynamic extent.
    auto new_shape =
        target_shape_type.getElementType().isIndex()
            ? args[1]
            : rewriter.create<IndexCastOp>(loc, extent_type, args[1]);
    auto reduction = rewriter.create<shape::ReduceOp>(
        loc, new_shape, llvm::makeArrayRef({one, zero, zero}));
    {
      PatternRewriter::InsertionGuard g(rewriter);
      auto body = reduction.getBody();
      rewriter.setInsertionPointToEnd(body);
      Value extent = body->getArgument(1);
      Value is_dynamic =
          rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, neg_one, extent);
      Value is_invalid =
          rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, extent, neg_one);
      Value total_dynamic = rewriter.create<AddIOp>(
          loc, rewriter.create<SelectOp>(loc, is_dynamic, one, zero),
          body->getArgument(3));
      Value total_invalid = rewriter.create<AddIOp>(
          loc, rewriter.create<SelectOp>(loc, is_invalid, one, zero),
          body->getArgument(4));
      Value extent_or_one =
          rewriter.create<SelectOp>(loc, is_dynamic, one, extent);
      Value total_elements =
          rewriter.create<MulIOp>(loc, extent_or_one, body->getArgument(2));
      rewriter.create<shape::YieldOp>(
          loc,
          llvm::makeArrayRef({total_elements, total_dynamic, total_invalid}));
    }
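    // For example (hypothetical values, for illustration only): for a target
    // shape of [2, -1, 4] the reduction yields total_elements = 8 (a dynamic
    // extent counts as 1), total_dynamic = 1, and total_invalid = 0.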
    // The reshape is valid iff the element count is divisible by the product
    // of the static extents, at most one extent is dynamic (-1), and no
    // extent is invalid (less than -1).
    Value is_divisible = rewriter.create<CmpIOp>(
        loc, CmpIPredicate::eq, zero,
        rewriter.create<SignedRemIOp>(loc, num_elements,
                                      reduction->getResult(0)));
    Value acceptably_dynamic = rewriter.create<CmpIOp>(
        loc, CmpIPredicate::ult, reduction->getResult(1), two);
    Value no_invalid = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, zero,
                                               reduction->getResult(2));

    Value all_passing = rewriter.create<AndOp>(
        loc, is_divisible,
        rewriter.create<AndOp>(loc, acceptably_dynamic, no_invalid));

    rewriter.replaceOpWithNewOp<shape::CstrRequireOp>(
        op, all_passing, "Required valid reshape shape input");

    return success();
  }
};

// Converts LHLO ops to Linalg generic.
// Sample result for lmhlo::AddOp.
//
// "lmhlo.add"(%arg1, %arg2, %out) :
//   (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
//
// will be converted to
//
// #map0 = affine_map<(d0, d1) -> (d0, d1)>
// "linalg.generic"(%arg1, %arg2, %out) ( {
//   ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
//     %0 = addf %arg4, %arg5 : f32
//     "linalg.yield"(%0) : (f32) -> ()
// }) {
//   indexing_maps = [#map0, #map0, #map0],
//   iterator_types = ["parallel", "parallel"],
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
struct LhloLegalizeToLinalgPass
    : public lmhlo::LhloLegalizeToLinalgPassBase<LhloLegalizeToLinalgPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry
        .insert<AffineDialect, complex::ComplexDialect, linalg::LinalgDialect,
                math::MathDialect, memref::MemRefDialect>();
  }

  void runOnFunction() override {
    OwningRewritePatternList patterns(&getContext());
    ConversionTarget target(getContext());
    target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
                           math::MathDialect, memref::MemRefDialect,
                           StandardOpsDialect, AffineDialect>();
    target.addLegalOp<UnrealizedConversionCastOp>();

    mhlo::RemoveSignTypeConverter type_converter;
    auto func = getFunction();
    populateLHLOToLinalgConversionPattern(func.getContext(), type_converter,
                                          &patterns);
    if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

struct HloLegalizeToLinalgPass
    : public mhlo::HloLegalizeToLinalgPassBase<HloLegalizeToLinalgPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<linalg::LinalgDialect, scf::SCFDialect,
                    complex::ComplexDialect, math::MathDialect,
                    memref::MemRefDialect, shape::ShapeDialect>();
  }

  void runOnFunction() override {
    MLIRContext& ctx = getContext();
    OwningRewritePatternList patterns(&ctx);
    ConversionTarget target(ctx);
    target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
                           math::MathDialect, StandardOpsDialect,
                           tensor::TensorDialect, scf::SCFDialect,
                           shape::ShapeDialect>();

    target.addLegalOp<UnrealizedConversionCastOp>();

    mhlo::RemoveSignTypeConverter type_converter;
    auto func = getFunction();
    mhlo::populateHLOToLinalgConversionPattern(&ctx, type_converter, &patterns);
    if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

}  // namespace

namespace lmhlo {
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass() {
  return std::make_unique<LhloLegalizeToLinalgPass>();
}
}  // namespace lmhlo

namespace mhlo {

void populateHLOToLinalgConversionPattern(MLIRContext* context,
                                          TypeConverter& type_converter,
                                          OwningRewritePatternList* patterns) {
  // clang-format off
  patterns->insert<
      BroadcastConverter<mhlo::BroadcastOp, false>, ConcatenateConverter,
      ConstConverterTensor, HloDynamicBroadcastInDimConverter,
      HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
      EinsumToLinalgConverter,
      IotaConverter<mhlo::DynamicIotaOp, false>,
      PointwiseToLinalgConverter<mhlo::AbsOp, false>,
      PointwiseToLinalgConverter<mhlo::AddOp, false>,
      PointwiseToLinalgConverter<mhlo::AndOp, false>,
      PointwiseToLinalgConverter<mhlo::Atan2Op, false>,
      PointwiseToLinalgConverter<mhlo::CeilOp, false>,
      PointwiseToLinalgConverter<mhlo::ClampOp, false>,
      PointwiseToLinalgConverter<mhlo::CompareOp, false>,
      PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
      PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
      PointwiseToLinalgConverter<mhlo::CopyOp, false>,
      PointwiseToLinalgConverter<mhlo::CosOp, false>,
      PointwiseToLinalgConverter<mhlo::DivOp, false>,
      PointwiseToLinalgConverter<mhlo::ExpOp, false>,
      PointwiseToLinalgConverter<mhlo::Expm1Op, false>,
      PointwiseToLinalgConverter<mhlo::FloorOp, false>,
      PointwiseToLinalgConverter<mhlo::ImagOp, false>,
      PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
      PointwiseToLinalgConverter<mhlo::LogOp, false>,
      PointwiseToLinalgConverter<mhlo::LogisticOp, false>,
      PointwiseToLinalgConverter<mhlo::Log1pOp, false>,
      PointwiseToLinalgConverter<mhlo::MaxOp, false>,
      PointwiseToLinalgConverter<mhlo::MinOp, false>,
      PointwiseToLinalgConverter<mhlo::MulOp, false>,
      PointwiseToLinalgConverter<mhlo::NegOp, false>,
      PointwiseToLinalgConverter<mhlo::NotOp, false>,
      PointwiseToLinalgConverter<mhlo::OrOp, false>,
      PointwiseToLinalgConverter<mhlo::PowOp, false>,
      PointwiseToLinalgConverter<mhlo::RealOp, false>,
      PointwiseToLinalgConverter<mhlo::RemOp, false>,
      PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
      PointwiseToLinalgConverter<mhlo::SelectOp, false>,
      PointwiseToLinalgConverter<mhlo::ShiftLeftOp, false>,
      PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp, false>,
      PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp, false>,
      PointwiseToLinalgConverter<mhlo::SignOp, false>,
      PointwiseToLinalgConverter<mhlo::SinOp, false>,
      PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
      PointwiseToLinalgConverter<mhlo::SubOp, false>,
      PointwiseToLinalgConverter<mhlo::TanhOp, false>,
      PointwiseToLinalgConverter<mhlo::XorOp, false>,
      ReshapeOpConverter<mhlo::ReshapeOp, false>,
      ReverseConverter<mhlo::ReverseOp, false>,
      SliceConverter<mhlo::SliceOp, false>,
      ComputeReshapeShapeConversion,
      CstrReshapableConversion,
      DynamicSliceConverter,
      DynamicUpdateSliceConverter,
      TransposeConverter<mhlo::TransposeOp, false>,
      DotOpOnTensorsConversion<DotOperationType::kMatrixMatrix,
                               linalg::MatmulOp>,
      DotOpOnTensorsConversion<DotOperationType::kMatrixVector,
                               linalg::MatvecOp>,
      DotOpOnTensorsConversion<DotOperationType::kVectorDot, linalg::DotOp>,
      DotGeneralOpOnTensorsConversion,
      NormalConvOpOnTensorsConversion,
      DepthwiseConvOpOnTensorsConversion,
      ReduceOnTensorsConversion,
      ReduceWindowOpOnTensorsConversion,
      ScatterUpdateOnTensorsConversion,
      TorchIndexSelectOpOnTensorsConversion,
      PadOpOnTensorsConversion>(type_converter, context);
  // clang-format on
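  // Give the reduce-region patterns a high benefit so that ops inside a
  // reduction region are matched by these conversions before the generic
  // pointwise converters registered above.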
  patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,
                   ReduceRegionXLAOpConversion<mhlo::AndOp>,
                   ReduceRegionXLAOpConversion<mhlo::CompareOp>,
                   ReduceRegionXLAOpConversion<mhlo::MaxOp>,
                   ReduceRegionXLAOpConversion<mhlo::MinOp>,
                   ReduceRegionXLAOpConversion<mhlo::MulOp>,
                   ReduceRegionXLAOpConversion<mhlo::OrOp>,
                   ReduceRegionXLAOpConversion<mhlo::SelectOp>,
                   ReduceRegionReturnOpConversion>(context,
                                                   PatternBenefit(1000));
}

std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
  return std::make_unique<HloLegalizeToLinalgPass>();
}

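// A minimal usage sketch (client code, not part of this file): the pass is
// typically scheduled on every function nested under a module, e.g.
//
//   PassManager pm(&context);
//   pm.addNestedPass<FuncOp>(mhlo::createLegalizeHloToLinalgPass());
//   if (failed(pm.run(module))) { /* handle the error */ }
//
// Here `context` and `module` are assumed to be a preconfigured MLIRContext
// and a parsed ModuleOp, respectively.
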
std::unique_ptr<TypeConverter> createHloToLinalgSignedIntegerConverter() {
  return std::make_unique<RemoveSignTypeConverter>();
}

}  // namespace mhlo
}  // namespace mlir