/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

==============================================================================*/

#include <utility>

#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {

/// Needed to build `llvm::SmallSet`s and `llvm::EquivalenceClasses` of
/// `mlir::Value`s.
static bool operator<(const Value &lhs, const Value &rhs) {
  return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
}

namespace mhlo {
namespace {

/// Identify clusters of operations that can be rank-specialized together. The
/// required traits for clustered operations are:
///   - Element-wise: All operations in the group must be element-wise. This
///     allows us to reshape operands before applying the operations and to
///     reshape the result to the desired shape afterwards. This way, we can,
///     e.g., apply unary ops to a completely flattened operand and restore the
///     original shape afterwards.
///   - Broadcasting semantics: All operations must implement broadcasting
///     semantics. Most importantly, this allows extending operand shapes such
///     that they match in rank. Operations that require all their operands to
///     be of the same shape also fulfill this requirement.
///   - Shape reification: All operations must implement
///     `InferShapedTypeOpInterface`. This is later needed to compute and to
///     restore the desired result shape.

bool IsClusterable(Operation *op) {
  if (!llvm::isa<InferShapedTypeOpInterface>(op)) return false;
  if (op->getNumOperands() == 0) return false;
  return (op->hasTrait<mlir::OpTrait::Elementwise>() &&
          op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) ||
         op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>();
}
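// For orientation: the pattern below materializes such a cluster as a
// `chlo.rank_specialization_cluster` op whose body is terminated by a
// `chlo.rank_specialization_cluster_yield` op. An illustrative sketch (the
// operand types and nested ops are hypothetical and the syntax is
// approximate):
//
//   %0 = "chlo.rank_specialization_cluster"(%arg0, %arg1) ({
//   ^bb0(%a0: tensor<*xf32>, %a1: tensor<*xf32>):
//     %1 = chlo.broadcast_add %a0, %a1
//         : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
//     "chlo.rank_specialization_cluster_yield"(%1) : (tensor<*xf32>) -> ()
//   }) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>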

struct RankSpecializationClusterPattern : public RewritePattern {
  explicit RankSpecializationClusterPattern(MLIRContext *ctx)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // Only apply to operations that have not been clustered yet.
    if (op->getParentOfType<chlo::RankSpecializationClusterOp>()) {
      return failure();
    }

    // Only cluster when rank specialization is needed.
    if (!IsClusterable(op) || !llvm::any_of(op->getOperandTypes(), [](Type ty) {
          return ty.isa<UnrankedTensorType>();
        })) {
      return failure();
    }

    // Collect all collectively rank specializable ops.
    SmallVector<Operation *, 16> cluster;
    llvm::SmallSet<Value, 16> operand_set;
    llvm::SmallSet<Value, 16> result_set;

    Operation *root_op = op;
    while (root_op->getNextNode() != nullptr &&
           IsClusterable(root_op->getNextNode()))
      root_op = root_op->getNextNode();

    Operation *it = root_op;
    while (it != nullptr && IsClusterable(it)) {
      // Find results that escape the cluster.
      for (OpOperand &use : it->getUses()) {
        if (!llvm::is_contained(cluster, use.getOwner()))
          result_set.insert(use.get());
      }

      // Update cluster operands.
      for (OpResult v : it->getResults()) operand_set.erase(Value(v));
      for (OpOperand &v : it->getOpOperands()) operand_set.insert(v.get());

      cluster.push_back(it);
      it = it->getPrevNode();
    }

    // Create `RankSpecializationClusterOp`.
    auto operands = llvm::to_vector<16>(operand_set);
    auto results = llvm::to_vector<16>(result_set);
    auto result_types = llvm::to_vector<16>(
        llvm::map_range(result_set, [](Value v) { return v.getType(); }));
    Location loc = op->getLoc();
    auto cluster_op = rewriter.create<chlo::RankSpecializationClusterOp>(
        loc, result_types, operands);

    // Create body block.
    auto operand_types = llvm::to_vector<16>(
        llvm::map_range(operand_set, [](Value v) { return v.getType(); }));
    Block *block = rewriter.createBlock(&cluster_op.body(), {}, operand_types);

    // Copy operations into the body.
    BlockAndValueMapping bvm;
    for (auto it : llvm::zip(operands, block->getArguments()))
      bvm.map(std::get<0>(it), std::get<1>(it));
    rewriter.setInsertionPointToStart(block);
    for (Operation *it : llvm::reverse(cluster)) rewriter.clone(*it, bvm);

    // Create `RankSpecializationClusterYieldOp`.
    auto mapped_results = llvm::to_vector<16>(
        llvm::map_range(results, [&](Value v) { return bvm.lookup(v); }));
    rewriter.create<chlo::RankSpecializationClusterYieldOp>(loc,
                                                            mapped_results);

    // Replace original ops with the new results.
    for (auto it : llvm::zip(results, cluster_op.results()))
      bvm.map(std::get<0>(it), std::get<1>(it));
    for (Operation *it : cluster) {
      if (it->getUses().empty()) {
        rewriter.eraseOp(it);
        continue;
      }
      auto replacements = llvm::to_vector<16>(llvm::map_range(
          it->getResults(), [&](Value v) { return bvm.lookup(v); }));
      rewriter.replaceOp(it, replacements);
    }

    return success();
  }
};

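// Merges a `chlo.rank_specialization_cluster` op into the immediately
// preceding cluster op, if there is one. Operands of the second cluster that
// originate in the preceding cluster are mapped to the corresponding values
// inside the merged body, and results of the preceding cluster that are used
// only by the second cluster are dropped from the merged op's results.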
struct MergeRankSpecializationClusterOpsPattern
    : public OpRewritePattern<chlo::RankSpecializationClusterOp> {
  using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
                                PatternRewriter &rewriter) const override {
    auto preceding_op =
        llvm::dyn_cast_or_null<chlo::RankSpecializationClusterOp>(
            op->getPrevNode());
    if (!preceding_op) return failure();
    Block *body = op.getBody();
    Block *preceding_body = preceding_op.getBody();
    auto yield_op = llvm::dyn_cast<chlo::RankSpecializationClusterYieldOp>(
        op.getBody()->getTerminator());
    auto preceding_yield_op =
        llvm::dyn_cast<chlo::RankSpecializationClusterYieldOp>(
            preceding_op.getBody()->getTerminator());

    // Merge cluster operands. Consider only those operands of the second
    // cluster that do not originate in the preceding cluster.
    SmallVector<Value, 8> new_operands;
    for (Value v : preceding_op.operands()) new_operands.push_back(v);
    for (Value v : op.operands()) {
      if (v.getDefiningOp() != preceding_op &&
          !llvm::is_contained(preceding_op.operands(), v)) {
        new_operands.push_back(v);
      }
    }

    // Merge cluster results. Consider only those results of the preceding
    // cluster that are not exclusively used as operands to the second cluster.
    SmallVector<Value, 8> new_unmapped_results;
    for (auto it :
         llvm::zip(preceding_op.results(), preceding_yield_op.results())) {
      Value result, inner_result;
      std::tie(result, inner_result) = it;
      if (!llvm::all_of(result.getUsers(),
                        [&](Operation *user) { return user == op; })) {
        new_unmapped_results.push_back(inner_result);
      }
    }
    for (Value v : yield_op.results()) new_unmapped_results.push_back(v);

    // Create merged cluster op.
    rewriter.setInsertionPoint(preceding_op);
    auto loc = op.getLoc();
    auto result_types = llvm::to_vector<16>(llvm::map_range(
        new_unmapped_results, [](Value v) { return v.getType(); }));
    auto new_op = rewriter.create<chlo::RankSpecializationClusterOp>(
        loc, result_types, new_operands);
    auto operand_types = llvm::to_vector<16>(
        llvm::map_range(new_operands, [](Value v) { return v.getType(); }));
    Block *new_body = rewriter.createBlock(&new_op.body(), {}, operand_types);
    rewriter.setInsertionPointToStart(new_body);

    // Map operands and copy operations of the preceding cluster into the new
    // body.
    BlockAndValueMapping bvm;
    for (auto it : llvm::enumerate(preceding_body->getArguments()))
      bvm.map(it.value(), new_body->getArgument(it.index()));
    for (Operation &nested_op : preceding_body->without_terminator())
      rewriter.clone(nested_op, bvm);

    // Map operands and copy operations of the second cluster. If they result
    // from the preceding cluster, we can simply map the corresponding value
    // internally.
    for (auto it : llvm::zip(body->getArguments(), op.operands())) {
      Value block_arg, operand;
      std::tie(block_arg, operand) = it;
      if (operand.getDefiningOp() == preceding_op) {
        auto where = llvm::find(preceding_op.results(), operand);
        assert(where.getBase() != nullptr &&
               "expected to find the operand among the preceding cluster's "
               "results");
        bvm.map(block_arg,
                bvm.lookup(preceding_yield_op.getOperand(where.getIndex())));
      } else {
        auto where = llvm::find(new_op.operands(), operand);
        bvm.map(block_arg, new_body->getArgument(where.getIndex()));
      }
    }
    for (Operation &nested_op : body->without_terminator()) {
      rewriter.clone(nested_op, bvm);
    }

    // Yield inner results.
    rewriter.create<chlo::RankSpecializationClusterYieldOp>(
        loc,
        llvm::to_vector<16>(llvm::map_range(new_unmapped_results, [&](Value v) {
          return bvm.lookupOrDefault(v);
        })));

    // Replace the two cluster ops with the new corresponding results.
    SmallVector<Value, 8> preceding_op_replacements;
    int64_t i = 0;
    for (Value result : preceding_op.results()) {
      Value replacement = nullptr;
      if (!llvm::all_of(result.getUsers(),
                        [&](Operation *user) { return user == op; })) {
        replacement = new_op->getResult(i++);
      }
      preceding_op_replacements.push_back(replacement);
    }
    ValueRange op_replacements = new_op.results().take_back(op.getNumResults());
    rewriter.replaceOp(op, op_replacements);
    rewriter.replaceOp(preceding_op, preceding_op_replacements);

    return success();
  }
};

struct RankSpecializationClusterPass
    : public RankSpecializationClusterPassBase<RankSpecializationClusterPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mhlo::MhloDialect, chlo::HloClientDialect>();
  }

  void runOnFunction() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    mhlo::PopulateRankSpecializationClusterPatterns(ctx, &patterns);
    if (failed(
            applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

/// Lower rank specialization cluster to SCF.

bool IsScalarTensorType(Type ty) {
  auto ranked_ty = ty.dyn_cast<RankedTensorType>();
  return ranked_ty && ranked_ty.getRank() == 0;
}

bool IsScalarShapeType(Type ty) {
  return ty.cast<RankedTensorType>().getDimSize(0) == 0;
}

Type DeriveRankedTensorTypes(Type ty, int64_t rank) {
  auto tensor_ty = ty.dyn_cast<TensorType>();
  if (!tensor_ty) return ty;
  SmallVector<int64_t, 8> shape(rank, ShapedType::kDynamicSize);
  return RankedTensorType::get(shape, tensor_ty.getElementType());
}

Type DeriveUnrankedTensorTypes(Type ty) {
  if (auto ranked_ty = ty.dyn_cast<RankedTensorType>())
    return UnrankedTensorType::get(ranked_ty.getElementType());
  return ty;
}

SmallVector<Value, 8> MaterializeRankedOperations(
    OpBuilder &b, Location loc, BlockAndValueMapping &bvm,
    chlo::RankSpecializationClusterOp op) {
  // Create ranked operations.
  for (Operation &nested_op : op.getBody()->without_terminator()) {
    auto mapped_operands = llvm::to_vector<4>(llvm::map_range(
        nested_op.getOperands(), [&](Value v) { return bvm.lookup(v); }));
    int64_t target_rank = 0;
    for (Value v : mapped_operands) {
      target_rank =
          std::max(target_rank, v.getType().cast<RankedTensorType>().getRank());
    }
    auto ranked_result_types = llvm::to_vector<2>(llvm::map_range(
        nested_op.getResultTypes(),
        [&](Type ty) { return DeriveRankedTensorTypes(ty, target_rank); }));
    OperationState ranked_op_state(loc, nested_op.getName().getStringRef(),
                                   mapped_operands, ranked_result_types,
                                   nested_op.getAttrs());
    Operation *ranked_op = b.createOperation(ranked_op_state);
    for (auto it : llvm::zip(nested_op.getResults(), ranked_op->getResults()))
      bvm.map(std::get<0>(it), std::get<1>(it));
  }

  // Collect ranked results.
  auto yield_op = llvm::cast<chlo::RankSpecializationClusterYieldOp>(
      op.getBody()->getTerminator());
  return llvm::to_vector<8>(llvm::map_range(
      yield_op.results(), [&](Value v) { return bvm.lookup(v); }));
}

SmallVector<Value, 8> MaterializeFinalReshape(
    PatternRewriter &rewriter, Location loc,
    chlo::RankSpecializationClusterOp op, ValueRange unshaped_results) {
  auto yield_op = llvm::cast<chlo::RankSpecializationClusterYieldOp>(
      op.getBody()->getTerminator());
  assert(unshaped_results.size() == 1 && yield_op.results().size() == 1 &&
         "Currently, rank specialization supports only one result.");

  // Reify result shape.
  Operation *last_op_before_shape_reification = op->getPrevNode();
  SmallVector<Value, 1> result_shape;
  Value original_result = yield_op.results().front();
  auto original_result_iface =
      llvm::cast<InferShapedTypeOpInterface>(original_result.getDefiningOp());
  if (failed(original_result_iface.reifyReturnTypeShapes(
          rewriter, original_result_iface->getOperands(), result_shape))) {
    return {};
  }

  // Materialize final reshape.
  Value unshaped_result = unshaped_results.front();
  Value result = rewriter.create<mhlo::DynamicReshapeOp>(
      loc, DeriveUnrankedTensorTypes(unshaped_result.getType()),
      unshaped_result, result_shape.front());

  // Reify shapes until they are independent of operations in the original
  // cluster.
  {
    Operation *it = result_shape.front().getDefiningOp();
    while (it != nullptr && it != last_op_before_shape_reification) {
      bool advanced = false;
      if (auto shape_of_op = llvm::dyn_cast<shape::ShapeOfOp>(it)) {
        Operation *def = shape_of_op.arg().getDefiningOp();
        if (def && def->getBlock() == op.getBody()) {
          // Resolve the `shape_of` op because it still depends on an
          // operation in the original cluster.
          OpBuilder::InsertionGuard guard(rewriter);
          rewriter.setInsertionPoint(shape_of_op);
          SmallVector<Value, 1> tmp_shape;
          auto iface = llvm::cast<InferShapedTypeOpInterface>(def);
          if (failed(iface.reifyReturnTypeShapes(rewriter, iface->getOperands(),
                                                 tmp_shape)))
            return {};
          rewriter.replaceOp(shape_of_op, tmp_shape.front());

          // Continue, including the newly created operations.
          it = tmp_shape.front().getDefiningOp();
          advanced = true;
        }
      }

      // Otherwise, skip the op.
      if (!advanced) it = it->getPrevNode();
    }
  }

  // Replace all remaining uses of the original cluster's block args.
  for (auto it : llvm::zip(op.operands(), op.getBody()->getArguments())) {
    Value operand, barg;
    std::tie(operand, barg) = it;
    barg.replaceUsesWithIf(operand, [&](OpOperand &operand) {
      return operand.getOwner()->getBlock() != op.getBody();
    });
  }

  return {result};
}

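// Materializes the flat shape, i.e. a rank-1 extent tensor holding the total
// number of elements, for a set of operand shapes that are equal (or assumed
// to be equal) at runtime.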
Value MaterializeFlatShape(OpBuilder &b, Location loc, ValueRange same_shapes) {
  assert(!same_shapes.empty() && "Expected at least one shape.");
  Value shape = same_shapes.size() == 1
                    ? same_shapes.front()
                    : b.create<shape::AnyOp>(loc, same_shapes.front().getType(),
                                             same_shapes);
  return b.create<tensor::FromElementsOp>(
      loc,
      b.create<shape::NumElementsOp>(loc, b.getIndexType(), shape).result());
}

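// Materializes the case in which all operands other than
// `non_scalars_of_same_shape` are scalars at runtime: the designated
// non-scalars are flattened to rank 1, the remaining operands are reshaped to
// rank 0 where needed, and the cluster's operations are applied to the ranked
// values. The remaining cases are materialized through `else_builder_fn`.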
Value MaterializeScalarRankSpecializationCase(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes, ValueRange non_scalars_of_same_shape,
    function_ref<void(OpBuilder &, Location)> else_builder_fn) {
  // Materialize predicate: All operands are scalars, except the expected
  // non-scalars.
  Value one = b.create<ConstantIndexOp>(loc, 1);
  Value all_others_are_scalar;
  for (auto it : llvm::zip(op.operands(), shapes)) {
    Value operand, shape;
    std::tie(operand, shape) = it;
    if (llvm::is_contained(non_scalars_of_same_shape, operand) ||
        IsScalarTensorType(operand.getType())) {
      continue;
    }
    auto literal =
        b.create<CmpIOp>(loc, CmpIPredicate::eq,
                         b.create<shape::NumElementsOp>(loc, shape), one);
    all_others_are_scalar =
        all_others_are_scalar
            ? b.create<mlir::AndOp>(loc, all_others_are_scalar, literal)
                  .getResult()
            : literal.result();
  }

  auto if_op = b.create<scf::IfOp>(
      loc, op->getResultTypes(), all_others_are_scalar,
      [&](OpBuilder &b, Location loc) {
        // Compute flat non-scalar shape.
        SmallVector<Value, 4> non_scalar_shapes;
        for (auto it : llvm::zip(op.operands(), shapes)) {
          Value operand, shape;
          std::tie(operand, shape) = it;
          if (llvm::is_contained(non_scalars_of_same_shape, operand))
            non_scalar_shapes.push_back(shape);
        }
        Value flat_shape = MaterializeFlatShape(b, loc, non_scalar_shapes);

        // Derive ranked operands.
        auto ranked_operands =
            llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
              if (IsScalarTensorType(v.getType())) return v;
              if (!llvm::is_contained(non_scalars_of_same_shape, v)) {
                return b
                    .create<mhlo::ReshapeOp>(
                        loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/0),
                        v)
                    .getResult();
              }
              return b
                  .create<mhlo::DynamicReshapeOp>(
                      loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/1), v,
                      flat_shape)
                  .getResult();
            }));

        // Materialize ranked variants for the element-wise operations.
        BlockAndValueMapping bvm;
        for (auto it : llvm::zip(op.getBody()->getArguments(), ranked_operands))
          bvm.map(std::get<0>(it), std::get<1>(it));
        Value unshaped_result =
            MaterializeRankedOperations(b, loc, bvm, op).front();

        // Return as unranked tensor for compatibility with the other cases.
        b.create<scf::YieldOp>(
            loc, b.create<tensor::CastOp>(
                      loc, DeriveUnrankedTensorTypes(unshaped_result.getType()),
                      unshaped_result)
                     .dest());
      },
      else_builder_fn);

  return if_op.results().front();
}

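// Materializes the case in which all non-scalar operands have equal shapes at
// runtime: all non-scalar operands are flattened to rank 1 and the cluster's
// operations are applied to the flat values. The remaining cases are
// materialized through `else_builder_fn`.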
Value MaterializeEqualShapesRankSpecializationCase(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes,
    function_ref<void(OpBuilder &, Location)> else_builder_fn) {
  // Materialize the all-shapes-equal predicate.
  Value all_shapes_eq_or_scalar;
  auto non_scalar_shapes = llvm::to_vector<8>(llvm::make_filter_range(
      shapes, [](Value v) { return !IsScalarShapeType(v.getType()); }));
  assert(
      non_scalar_shapes.size() >= 2 &&
      "Equal shapes strategy requires at least two non-scalar operand shapes.");
  for (Value s : llvm::drop_begin(non_scalar_shapes)) {
    auto literal =
        b.create<shape::ShapeEqOp>(loc, non_scalar_shapes.front(), s);
    all_shapes_eq_or_scalar =
        all_shapes_eq_or_scalar
            ? b.create<mlir::AndOp>(loc, all_shapes_eq_or_scalar, literal)
                  .result()
            : literal;
  }

  auto if_op = b.create<scf::IfOp>(
      loc, op->getResultTypes(), all_shapes_eq_or_scalar,
      [&](OpBuilder &b, Location loc) {
        // Flatten non-scalar operands.
        Value flat_shape = MaterializeFlatShape(b, loc, non_scalar_shapes);
        auto flat_operands =
            llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
              if (IsScalarTensorType(v.getType())) return v;
              return b
                  .create<mhlo::DynamicReshapeOp>(
                      loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/1), v,
                      flat_shape)
                  .result();
            }));

        // Materialize ranked variants for the element-wise operations.
        BlockAndValueMapping bvm;
        for (auto it : llvm::zip(op.getBody()->getArguments(), flat_operands))
          bvm.map(std::get<0>(it), std::get<1>(it));
        Value unshaped_result =
            MaterializeRankedOperations(b, loc, bvm, op).front();

        // Return as unranked tensor for compatibility with the other cases.
        b.create<scf::YieldOp>(
            loc, b.create<tensor::CastOp>(
                      loc, DeriveUnrankedTensorTypes(unshaped_result.getType()),
                      unshaped_result)
                     .dest());
      },
      else_builder_fn);

  return if_op.results().front();
}

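// Materializes the cluster's operations for one fixed target rank. Unranked
// operand shapes are broadcast with an all-ones shape of rank `target_rank`
// and the corresponding operands are reshaped accordingly, so that all
// operands are ranked before the element-wise operations are materialized.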
Value MaterializeTargetRankSpecializationCase(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes, int64_t target_rank) {
  // Reshape unranked operands to match the target rank.
  RankedTensorType extent_tensor_ty =
      shape::getExtentTensorType(b.getContext(), target_rank);
  Value all_ones_shape = b.create<shape::ConstShapeOp>(
      loc, extent_tensor_ty,
      mlir::DenseIntElementsAttr::get(extent_tensor_ty,
                                      SmallVector<int64_t, 6>(target_rank, 1)));
  SmallVector<Value, 8> ranked_operands;
  for (auto it : llvm::zip(op.operands(), shapes)) {
    Value operand, shape;
    std::tie(operand, shape) = it;
    if (operand.getType().isa<RankedTensorType>()) {
      ranked_operands.push_back(operand);
      continue;
    }
    Value ranked_shape = b.create<tensor::CastOp>(
        loc, extent_tensor_ty,
        b.create<shape::BroadcastOp>(loc,
                                     shape::getExtentTensorType(b.getContext()),
                                     shape, all_ones_shape,
                                     /*error=*/nullptr));
    ranked_operands.push_back(b.create<mhlo::DynamicReshapeOp>(
        loc, DeriveRankedTensorTypes(operand.getType(), target_rank), operand,
        ranked_shape));
  }

  // Materialize ranked versions of the element-wise operations.
  BlockAndValueMapping bvm;
  for (auto it : llvm::zip(op.body().front().getArguments(), ranked_operands))
    bvm.map(std::get<0>(it), std::get<1>(it));

  // Return as unranked for compatibility with other target ranks.
  auto unshaped_result = MaterializeRankedOperations(b, loc, bvm, op).front();
  return b.create<tensor::CastOp>(
      loc, DeriveUnrankedTensorTypes(unshaped_result.getType()),
      unshaped_result);
}

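// Recursively materializes a cascade of rank specialization cases for the
// target ranks `min_target_rank` .. `max_target_rank`. Conceptually (sketch
// only):
//
//   if (max_rank <= min_target_rank) {
//     <case specialized to rank min_target_rank>
//   } else if (max_rank <= min_target_rank + 1) {
//     <case specialized to rank min_target_rank + 1>
//   } else {
//     ...
//     assert(max_rank <= max_target_rank);
//     <case specialized to rank max_target_rank>
//   }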
Value RecusivelyMaterializeTargetRankSpecializationCases(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes, Value max_rank,
    int64_t min_target_rank, int64_t max_target_rank) {
  Value condition =
      b.create<CmpIOp>(loc, CmpIPredicate::ule, max_rank,
                       b.create<ConstantIndexOp>(loc, min_target_rank));

  // If only a single target rank is left, we can lower to an assert instead
  // of the usual if operation.
  if (min_target_rank == max_target_rank) {
    b.create<AssertOp>(loc, condition,
                       "Input for dynamic binary or n-ary op lowering was of "
                       "a rank greater than " +
                           std::to_string(max_target_rank));
    return MaterializeTargetRankSpecializationCase(b, loc, op, shapes,
                                                   min_target_rank);
  }

  // Materialize IR for the smallest considered target rank.
  auto if_op = b.create<scf::IfOp>(loc, op->getResultTypes(), condition,
                                   /*withElseRegion=*/true);
  auto then_builder = if_op.getThenBodyBuilder();
  then_builder.create<scf::YieldOp>(
      loc, MaterializeTargetRankSpecializationCase(then_builder, loc, op,
                                                   shapes, min_target_rank));

  // Recurse for all remaining target ranks.
  auto else_builder = if_op.getElseBodyBuilder();
  else_builder.create<scf::YieldOp>(
      loc, RecusivelyMaterializeTargetRankSpecializationCases(
               else_builder, loc, op, shapes, max_rank, min_target_rank + 1,
               max_target_rank));

  return if_op.results().front();
}

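// Materializes the generic fallback cases: the non-scalar operand shapes are
// first reduced via `chlo::MinimumBroadcastShapesOp`, the maximum rank among
// the reduced shapes is computed, and a specialization case is materialized
// for every target rank from 1 up to `max_target_rank`.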
Value MaterializeGenericRankSpecializationCases(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes, int64_t max_target_rank) {
  // Get the minimum broadcast shapes of the operands.
  auto non_scalar_shapes = llvm::to_vector<8>(llvm::make_filter_range(
      shapes, [](Value v) { return !IsScalarShapeType(v.getType()); }));
  auto min_bcast_shapes_op = b.create<chlo::MinimumBroadcastShapesOp>(
      loc,
      SmallVector<Type, 8>(non_scalar_shapes.size(),
                           shape::getExtentTensorType(b.getContext())),
      non_scalar_shapes);

  // Find the maximum rank among the reduced operand shapes.
  Value max_rank;
  for (Value shape : min_bcast_shapes_op.results()) {
    Value rank = b.create<shape::RankOp>(loc, b.getIndexType(), shape);
    if (!max_rank) {
      max_rank = rank;
    } else {
      max_rank = b.create<mlir::SelectOp>(
          loc, b.create<CmpIOp>(loc, CmpIPredicate::sgt, max_rank, rank),
          max_rank, rank);
    }
  }

  // Collect reduced shapes.
  SmallVector<Value, 8> reduced_shapes;
  auto it = min_bcast_shapes_op.result_begin();
  for (Value s : shapes) {
    if (IsScalarShapeType(s.getType())) {
      reduced_shapes.push_back(s);
    } else {
      reduced_shapes.push_back(*it++);
    }
  }

  // Materialize rank specialization for ranks 1, ...
  return RecusivelyMaterializeTargetRankSpecializationCases(
      b, loc, op, reduced_shapes, max_rank, /*min_target_rank=*/1,
      max_target_rank);
}

Value MaterializeDefaultRankSpecializationCases(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes, int64_t max_target_rank) {
  return MaterializeEqualShapesRankSpecializationCase(
      b, loc, op, shapes, [&](OpBuilder &b, Location loc) {
        b.create<scf::YieldOp>(loc, MaterializeGenericRankSpecializationCases(
                                        b, loc, op, shapes, max_target_rank));
      });
}

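// Materializes rank specialization for the case in which all non-scalar
// operands belong to a single shape equivalence class. All non-scalars can
// then be flattened to rank 1 unconditionally, so no runtime case distinction
// is needed.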
SmallVector<Value, 8>
MaterializeRankSpecializationForSingleNonScalarShapeEquivalenceClass(
    PatternRewriter &rewriter, Location loc,
    chlo::RankSpecializationClusterOp op,
    ValueRange non_scalars_of_same_shape) {
  // Compute flat operand shape.
  auto non_scalar_shapes = llvm::to_vector<4>(
      llvm::map_range(non_scalars_of_same_shape, [&](Value v) {
        return rewriter.create<shape::ShapeOfOp>(loc, v).result();
      }));
  Value flat_shape = MaterializeFlatShape(rewriter, loc, non_scalar_shapes);

  // Materialize ranked variants for the element-wise operations.
  BlockAndValueMapping bvm;
  for (auto it : llvm::zip(op.getBody()->getArguments(), op.operands())) {
    Value operand;
    Value bb_arg;
    std::tie(bb_arg, operand) = it;
    if (!IsScalarTensorType(operand.getType())) {
      assert(llvm::is_contained(non_scalars_of_same_shape, operand) &&
             "Expected all non-scalars in the same shape equivalence class.");
      operand = rewriter.create<mhlo::DynamicReshapeOp>(
          loc, DeriveRankedTensorTypes(operand.getType(), /*rank=*/1), operand,
          flat_shape);
    }
    bvm.map(bb_arg, operand);
  }
  SmallVector<Value, 8> unshaped_results =
      MaterializeRankedOperations(rewriter, loc, bvm, op);

  // Restore the results' expected shape.
  Value shape = non_scalar_shapes.front();
  return llvm::to_vector<8>(llvm::map_range(unshaped_results, [&](Value v) {
    return rewriter
        .create<mhlo::DynamicReshapeOp>(
            loc, DeriveUnrankedTensorTypes(v.getType()), v, shape)
        .result();
  }));
}

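// Materializes rank specialization for exactly two non-scalar shape
// equivalence classes. On top of the default cases, this handles the two
// cases in which either class turns out to consist of scalars only at
// runtime, in which case all operands can again be flattened.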
Value MaterializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses(
    PatternRewriter &rewriter, Location loc,
    chlo::RankSpecializationClusterOp op,
    SmallVector<SmallVector<Value, 4>, 4> non_scalar_eqs,
    int64_t max_target_rank) {
  assert(non_scalar_eqs.size() == 2 &&
         "Expect two non-scalar equivalence classes.");
  auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
    return rewriter.create<shape::ShapeOfOp>(loc, v).result();
  }));
  ValueRange lhs_non_scalar_eqs = non_scalar_eqs[0];
  ValueRange rhs_non_scalar_eqs = non_scalar_eqs[1];

  // Materialize all the different cases.
  Value unshaped_result = MaterializeScalarRankSpecializationCase(
      rewriter, loc, op, shapes, rhs_non_scalar_eqs,
      [&](OpBuilder &b, Location loc) {
        b.create<scf::YieldOp>(
            loc, MaterializeScalarRankSpecializationCase(
                     b, loc, op, shapes, lhs_non_scalar_eqs,
                     [&](OpBuilder &b, Location loc) {
                       b.create<scf::YieldOp>(
                           loc, MaterializeDefaultRankSpecializationCases(
                                    b, loc, op, shapes, max_target_rank));
                     }));
      });

  // Materialize final reshape once and for all rank specialization cases.
  return MaterializeFinalReshape(rewriter, loc, op, unshaped_result).front();
}

// Materialize the generic rank specialization.
Value MaterializeDefaultRankSpecialization(PatternRewriter &rewriter,
                                           Location loc,
                                           chlo::RankSpecializationClusterOp op,
                                           int64_t max_target_rank) {
  auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
    return rewriter.create<shape::ShapeOfOp>(loc, v).result();
  }));

  // Materialize all the different cases.
  Value unshaped_result = MaterializeDefaultRankSpecializationCases(
      rewriter, loc, op, shapes, max_target_rank);

  // Materialize final reshape once and for all rank specialization cases.
  return MaterializeFinalReshape(rewriter, loc, op, unshaped_result).front();
}

// This is a very limited form of shape inference. It is correct but incomplete.
SmallVector<SmallVector<Value, 4>, 4> FindNonScalarShapeEquivalences(
    chlo::RankSpecializationClusterOp op) {
  llvm::EquivalenceClasses<Value> eqs;

  // Bridge the equivalences between operands and block arguments.
  for (auto it : llvm::zip(op.operands(), op.getBody()->getArguments()))
    eqs.unionSets(std::get<0>(it), std::get<1>(it));

  // Find equalities through `SameOperandsAndResultShape` trait.
  auto union_sets = [&](ValueRange vs) {
    if (vs.empty()) return;
    Value repr = vs.front();
    for (Value v : vs.drop_front()) eqs.unionSets(repr, v);
  };
  for (Operation &nested_op : op.getBody()->without_terminator()) {
    if (nested_op.hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) {
      union_sets(nested_op.getOperands());
      union_sets(nested_op.getResults());
      if (!nested_op.getOperands().empty() && !nested_op.getResults().empty())
        eqs.unionSets(nested_op.getResult(0), nested_op.getOperand(0));
    }
  }

  // Find shape equalities through surrounding constraints.
  if (auto assuming_op = op->getParentOfType<shape::AssumingOp>()) {
    SmallVector<Operation *, 8> queue;
    auto append_if_not_null = [&](Operation *op) {
      if (op != nullptr) queue.push_back(op);
    };
    append_if_not_null(assuming_op.witness().getDefiningOp());
    while (!queue.empty()) {
      Operation *it = queue.pop_back_val();
      if (auto assuming_all_op = llvm::dyn_cast<shape::AssumingAllOp>(it)) {
        for (Value v : assuming_all_op.inputs())
          append_if_not_null(v.getDefiningOp());
      } else if (auto cstr_eq_op = llvm::dyn_cast<shape::CstrEqOp>(it)) {
        Value ref_arg;
        for (Value v : cstr_eq_op.shapes()) {
          if (auto shape_of_op =
                  dyn_cast_or_null<shape::ShapeOfOp>(v.getDefiningOp())) {
            if (!ref_arg) {
              ref_arg = shape_of_op.arg();
            } else {
              eqs.unionSets(ref_arg, shape_of_op.arg());
            }
          }
        }
      }
    }
  }

  // Find equalities through special knowledge of ops.
  // TODO(frgossen): Remove this when these shape equalities can be inferred
  // from surrounding shape constraints.
  for (Operation &nested_op : op.getBody()->without_terminator()) {
    if (auto select_op = llvm::dyn_cast<mhlo::SelectOp>(nested_op)) {
      union_sets(
          {select_op.on_true(), select_op.on_false(), select_op.getResult()});
    } else if (auto clamp_op = llvm::dyn_cast<mhlo::ClampOp>(nested_op)) {
      union_sets({clamp_op.operand(), clamp_op.getResult()});
    }
  }

  // Convert to a list-like equivalence class representation.
  SmallVector<SmallVector<Value, 4>, 4> non_scalar_eqs;
  for (Value v : op.operands()) {
    if (IsScalarTensorType(v.getType())) continue;
    bool inserted = false;
    for (auto &eq_class : non_scalar_eqs) {
      if (eqs.isEquivalent(eq_class.front(), v)) {
        eq_class.push_back(v);
        inserted = true;
        break;
      }
    }
    if (!inserted) non_scalar_eqs.push_back(SmallVector<Value, 4>({v}));
  }

  return non_scalar_eqs;
}

struct LowerRankSpecializationClusterPattern
    : public OpRewritePattern<chlo::RankSpecializationClusterOp> {
  LowerRankSpecializationClusterPattern(MLIRContext *ctx,
                                        int64_t max_target_rank)
      : OpRewritePattern<chlo::RankSpecializationClusterOp>(ctx, /*benefit=*/1),
        max_target_rank(max_target_rank) {}

  LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
                                PatternRewriter &rewriter) const override {
    // Restoring the result shape currently relies on all operands being used
    // for a single result. The result shape is then the broadcasted shape of
    // all operands.
    if (op.getNumResults() != 1) return failure();

    // If there is only a single non-scalar shape equivalence class, we can
    // flatten the operands completely.
    SmallVector<SmallVector<Value, 4>, 4> non_scalar_eqs =
        FindNonScalarShapeEquivalences(op);
    Location loc = op.getLoc();
    if (non_scalar_eqs.size() == 1) {
      rewriter.replaceOp(
          op,
          MaterializeRankSpecializationForSingleNonScalarShapeEquivalenceClass(
              rewriter, loc, op, non_scalar_eqs.front()));
      return success();
    }

    // If there are exactly two non-scalar shape equivalence classes, we can
    // consider two extra cases: If either of the operand classes turns out to
    // be all-scalars at runtime, we can, again, flatten all operands.
    if (non_scalar_eqs.size() == 2) {
      rewriter.replaceOp(
          op,
          MaterializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses(
              rewriter, loc, op, non_scalar_eqs, max_target_rank));
      return success();
    }

    // For all other cases, reshape the operands to match in rank, apply the
    // operation, and restore the expected shape.
    rewriter.replaceOp(op, MaterializeDefaultRankSpecialization(
                               rewriter, loc, op, max_target_rank));
    return success();
  }

 private:
  int64_t max_target_rank;
};

struct RankSpecializationToSCFPass
    : public RankSpecializationToSCFPassBase<RankSpecializationToSCFPass> {
  explicit RankSpecializationToSCFPass(int64_t max_target_rank)
      : RankSpecializationToSCFPassBase<
            RankSpecializationToSCFPass>::RankSpecializationToSCFPassBase() {
    this->max_target_rank_ = max_target_rank;
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mhlo::MhloDialect, chlo::HloClientDialect,
                    shape::ShapeDialect, scf::SCFDialect>();
  }

  void runOnFunction() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    PopulateRankSpecializationToSCFPatterns(ctx, &patterns,
                                            this->max_target_rank_);
    if (failed(
            applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

}  // namespace

void PopulateRankSpecializationClusterPatterns(
    MLIRContext *context, OwningRewritePatternList *patterns) {
  patterns->insert<MergeRankSpecializationClusterOpsPattern,
                   RankSpecializationClusterPattern>(context);
}

void PopulateRankSpecializationToSCFPatterns(MLIRContext *context,
                                             OwningRewritePatternList *patterns,
                                             int64_t max_target_rank) {
  patterns->insert<LowerRankSpecializationClusterPattern>(context,
                                                          max_target_rank);
  shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context);
  shape::ShapeOfOp::getCanonicalizationPatterns(*patterns, context);
  shape::AnyOp::getCanonicalizationPatterns(*patterns, context);
}

std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() {
  return std::make_unique<RankSpecializationClusterPass>();
}

std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass(
    int64_t max_target_rank) {
  return std::make_unique<RankSpecializationToSCFPass>(max_target_rank);
}

}  // namespace mhlo
}  // namespace mlir