/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

==============================================================================*/

#include <utility>

#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {

/// Needed to build `llvm::SmallSet`s and `llvm::EquivalenceClasses` of
/// `mlir::Value`s.
static bool operator<(const Value &lhs, const Value &rhs) {
  return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
}

namespace mhlo {
namespace {
/// Identify clusters of operations that can be rank-specialized together. The
/// required traits for clustered operations are:
///   - Element-wise: All operations in the group must be element-wise. This
///     allows us to reshape operands before applying the operations, as well
///     as to reshape the result to the desired shape afterwards. This way, we
///     can, e.g., apply unary ops to a completely flattened operand and
///     restore the original shape afterwards.
///   - Broadcasting semantics: All operations must implement broadcasting
///     semantics. Most importantly, this allows extending operand shapes such
///     that they match in rank. Operations that require all their operands to
///     be of the same shape also fulfill this requirement.
///   - Shape reification: All operations must implement
///     `InferShapedTypeOpInterface`. This is later needed to compute and to
///     restore the desired result shape.
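///
/// Schematically (illustrative only, the exact IR syntax may differ), a
/// cluster of two CHLO broadcasting ops over unranked tensors is outlined as:
///
///   %res = "chlo.rank_specialization_cluster"(%a, %b, %c) ({
///   ^bb0(%a_: tensor<*xf32>, %b_: tensor<*xf32>, %c_: tensor<*xf32>):
///     %tmp = chlo.broadcast_multiply %a_, %b_ : ...
///     %inner = chlo.broadcast_add %tmp, %c_ : ...
///     "chlo.rank_specialization_cluster_yield"(%inner) : ...
///   }) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>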

bool IsClusterable(Operation *op) {
  if (!llvm::isa<InferShapedTypeOpInterface>(op)) return false;
  if (op->getNumOperands() == 0) return false;
  return (op->hasTrait<mlir::OpTrait::Elementwise>() &&
          op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) ||
         op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>();
}

struct RankSpecializationClusterPattern : public RewritePattern {
  explicit RankSpecializationClusterPattern(MLIRContext *ctx)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // Only apply to operations that have not been clustered yet.
    if (op->getParentOfType<chlo::RankSpecializationClusterOp>()) {
      return failure();
    }

    // Only cluster when rank specialization is needed.
    if (!IsClusterable(op) || !llvm::any_of(op->getOperandTypes(), [](Type ty) {
          return ty.isa<UnrankedTensorType>();
        })) {
      return failure();
    }

    // Collect all collectively rank specializable ops.
    SmallVector<Operation *, 16> cluster;
    llvm::SmallSet<Value, 16> operand_set;
    llvm::SmallSet<Value, 16> result_set;

    Operation *root_op = op;
    while (root_op->getNextNode() != nullptr &&
           IsClusterable(root_op->getNextNode()))
      root_op = root_op->getNextNode();

    Operation *it = root_op;
    while (it != nullptr && IsClusterable(it)) {
      // Find results that escape the cluster.
      for (OpOperand &use : it->getUses()) {
        if (!llvm::is_contained(cluster, use.getOwner()))
          result_set.insert(use.get());
      }

      // Update cluster operands.
      for (OpResult v : it->getResults()) operand_set.erase(Value(v));
      for (OpOperand &v : it->getOpOperands()) operand_set.insert(v.get());

      cluster.push_back(it);
      it = it->getPrevNode();
    }

    // Create `RankSpecializationClusterOp`.
    auto operands = llvm::to_vector<16>(operand_set);
    auto results = llvm::to_vector<16>(result_set);
    auto result_types = llvm::to_vector<16>(
        llvm::map_range(result_set, [](Value v) { return v.getType(); }));
    Location loc = op->getLoc();
    auto cluster_op = rewriter.create<chlo::RankSpecializationClusterOp>(
        loc, result_types, operands);

    // Create body block.
    auto operand_types = llvm::to_vector<16>(
        llvm::map_range(operand_set, [](Value v) { return v.getType(); }));
    Block *block = rewriter.createBlock(&cluster_op.body(), {}, operand_types);

    // Copy operations into the body.
    BlockAndValueMapping bvm;
    for (auto it : llvm::zip(operands, block->getArguments()))
      bvm.map(std::get<0>(it), std::get<1>(it));
    rewriter.setInsertionPointToStart(block);
    for (Operation *it : llvm::reverse(cluster)) rewriter.clone(*it, bvm);

    // Create `RankSpecializationClusterYieldOp`.
    auto mapped_results = llvm::to_vector<16>(
        llvm::map_range(results, [&](Value v) { return bvm.lookup(v); }));
    rewriter.create<chlo::RankSpecializationClusterYieldOp>(loc,
                                                            mapped_results);

    // Replace original ops with the new results.
    for (auto it : llvm::zip(results, cluster_op.results()))
      bvm.map(std::get<0>(it), std::get<1>(it));
    for (Operation *it : cluster) {
      if (it->getUses().empty()) {
        rewriter.eraseOp(it);
        continue;
      }
      auto replacements = llvm::to_vector<16>(llvm::map_range(
          it->getResults(), [&](Value v) { return bvm.lookup(v); }));
      rewriter.replaceOp(it, replacements);
    }

    return success();
  }
};

struct MergeRankSpecializationClusterOpsPattern
    : public OpRewritePattern<chlo::RankSpecializationClusterOp> {
  using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
                                PatternRewriter &rewriter) const override {
    auto preceding_op =
        llvm::dyn_cast_or_null<chlo::RankSpecializationClusterOp>(
            op->getPrevNode());
    if (!preceding_op) return failure();
    Block *body = op.getBody();
    Block *preceding_body = preceding_op.getBody();
    auto yield_op = llvm::dyn_cast<chlo::RankSpecializationClusterYieldOp>(
        op.getBody()->getTerminator());
    auto preceding_yield_op =
        llvm::dyn_cast<chlo::RankSpecializationClusterYieldOp>(
            preceding_op.getBody()->getTerminator());

    // Merge cluster operands. Consider only those operands of the second
    // cluster that do not originate in the preceding cluster.
    SmallVector<Value, 8> new_operands;
    for (Value v : preceding_op.operands()) new_operands.push_back(v);
    for (Value v : op.operands()) {
      if (v.getDefiningOp() != preceding_op &&
          !llvm::is_contained(preceding_op.operands(), v)) {
        new_operands.push_back(v);
      }
    }

    // Merge cluster results. Consider only those results of the preceding
    // cluster that are not exclusively used as operands to the second cluster.
    SmallVector<Value, 8> new_unmapped_results;
    for (auto it :
         llvm::zip(preceding_op.results(), preceding_yield_op.results())) {
      Value result, inner_result;
      std::tie(result, inner_result) = it;
      if (!llvm::all_of(result.getUsers(),
                        [&](Operation *user) { return user == op; })) {
        new_unmapped_results.push_back(inner_result);
      }
    }
    for (Value v : yield_op.results()) new_unmapped_results.push_back(v);

    // Create merged cluster op.
    rewriter.setInsertionPoint(preceding_op);
    auto loc = op.getLoc();
    auto result_types = llvm::to_vector<16>(llvm::map_range(
        new_unmapped_results, [](Value v) { return v.getType(); }));
    auto new_op = rewriter.create<chlo::RankSpecializationClusterOp>(
        loc, result_types, new_operands);
    auto operand_types = llvm::to_vector<16>(
        llvm::map_range(new_operands, [](Value v) { return v.getType(); }));
    Block *new_body = rewriter.createBlock(&new_op.body(), {}, operand_types);
    rewriter.setInsertionPointToStart(new_body);

    // Map operands and copy operations of the preceding cluster into the new
    // body.
    BlockAndValueMapping bvm;
    for (auto it : llvm::enumerate(preceding_body->getArguments()))
      bvm.map(it.value(), new_body->getArgument(it.index()));
    for (Operation &nested_op : preceding_body->without_terminator())
      rewriter.clone(nested_op, bvm);

    // Map operands and copy operations of the second cluster. If they result
    // from the preceding cluster, we can simply map the corresponding value
    // internally.
    for (auto it : llvm::zip(body->getArguments(), op.operands())) {
      Value block_arg, operand;
      std::tie(block_arg, operand) = it;
      if (operand.getDefiningOp() == preceding_op) {
        auto where = llvm::find(preceding_op.results(), operand);
        assert(where.getBase() != nullptr && "expected to find operand");
        bvm.map(block_arg,
                bvm.lookup(preceding_yield_op.getOperand(where.getIndex())));
      } else {
        auto where = llvm::find(new_op.operands(), operand);
        bvm.map(block_arg, new_body->getArgument(where.getIndex()));
      }
    }
    for (Operation &nested_op : body->without_terminator()) {
      rewriter.clone(nested_op, bvm);
    }

    // Yield inner results.
    rewriter.create<chlo::RankSpecializationClusterYieldOp>(
        loc,
        llvm::to_vector<16>(llvm::map_range(new_unmapped_results, [&](Value v) {
          return bvm.lookupOrDefault(v);
        })));

    // Replace the two cluster ops with the new corresponding results.
    SmallVector<Value, 8> preceding_op_replacements;
    int64_t i = 0;
    for (Value result : preceding_op.results()) {
      Value replacement = nullptr;
      if (!llvm::all_of(result.getUsers(),
                        [&](Operation *user) { return user == op; })) {
        replacement = new_op->getResult(i++);
      }
      preceding_op_replacements.push_back(replacement);
    }
    ValueRange op_replacements = new_op.results().take_back(op.getNumResults());
    rewriter.replaceOp(op, op_replacements);
    rewriter.replaceOp(preceding_op, preceding_op_replacements);

    return success();
  }
};

struct RankSpecializationClusterPass
    : public RankSpecializationClusterPassBase<RankSpecializationClusterPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mhlo::MhloDialect, chlo::HloClientDialect>();
  }

  void runOnFunction() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    mhlo::PopulateRankSpecializationClusterPatterns(ctx, &patterns);
    if (failed(
            applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

/// Lower rank specialization cluster to SCF.
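///
/// The lowering below materializes a cascade of cases for each cluster:
///   - If all non-scalar operands are known to share one shape equivalence
///     class (or, with two classes, if one class turns out to be all scalars
///     at runtime), operands are flattened to rank 1 and the cluster is
///     applied to the flat values.
///   - Otherwise, an equal-shapes case is emitted, followed by one `scf.if`
///     case per target rank from 1 up to `max_target_rank`, with a final
///     assert for ranks that exceed the maximum.
/// A single reshape at the end restores the broadcasted result shape.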

bool IsScalarTensorType(Type ty) {
  auto ranked_ty = ty.dyn_cast<RankedTensorType>();
  return ranked_ty && ranked_ty.getRank() == 0;
}

bool IsScalarShapeType(Type ty) {
  return ty.cast<RankedTensorType>().getDimSize(0) == 0;
}

Type DeriveRankedTensorTypes(Type ty, int64_t rank) {
  auto tensor_ty = ty.dyn_cast<TensorType>();
  if (!tensor_ty) return ty;
  SmallVector<int64_t, 8> shape(rank, ShapedType::kDynamicSize);
  return RankedTensorType::get(shape, tensor_ty.getElementType());
}

Type DeriveUnrankedTensorTypes(Type ty) {
  if (auto ranked_ty = ty.dyn_cast<RankedTensorType>())
    return UnrankedTensorType::get(ranked_ty.getElementType());
  return ty;
}

SmallVector<Value, 8> MaterializeRankedOperations(
    OpBuilder &b, Location loc, BlockAndValueMapping &bvm,
    chlo::RankSpecializationClusterOp op) {
  // Create ranked operations.
  for (Operation &nested_op : op.getBody()->without_terminator()) {
    auto mapped_operands = llvm::to_vector<4>(llvm::map_range(
        nested_op.getOperands(), [&](Value v) { return bvm.lookup(v); }));
    int64_t target_rank = 0;
    for (Value v : mapped_operands) {
      target_rank =
          std::max(target_rank, v.getType().cast<RankedTensorType>().getRank());
    }
    auto ranked_result_types = llvm::to_vector<2>(llvm::map_range(
        nested_op.getResultTypes(),
        [&](Type ty) { return DeriveRankedTensorTypes(ty, target_rank); }));
    OperationState ranked_op_state(loc, nested_op.getName().getStringRef(),
                                   mapped_operands, ranked_result_types,
                                   nested_op.getAttrs());
    Operation *ranked_op = b.createOperation(ranked_op_state);
    for (auto it : llvm::zip(nested_op.getResults(), ranked_op->getResults()))
      bvm.map(std::get<0>(it), std::get<1>(it));
  }

  // Collect ranked results.
  auto yield_op = llvm::cast<chlo::RankSpecializationClusterYieldOp>(
      op.getBody()->getTerminator());
  return llvm::to_vector<8>(llvm::map_range(
      yield_op.results(), [&](Value v) { return bvm.lookup(v); }));
}

SmallVector<Value, 8> MaterializeFinalReshape(
    PatternRewriter &rewriter, Location loc,
    chlo::RankSpecializationClusterOp op, ValueRange unshaped_results) {
  auto yield_op = llvm::cast<chlo::RankSpecializationClusterYieldOp>(
      op.getBody()->getTerminator());
  assert(unshaped_results.size() == 1 && yield_op.results().size() == 1 &&
         "Currently, rank specialization supports only one result.");

  // Reify result shape.
  Operation *last_op_before_shape_reification = op->getPrevNode();
  SmallVector<Value, 1> result_shape;
  Value original_result = yield_op.results().front();
  auto original_result_iface =
      llvm::cast<InferShapedTypeOpInterface>(original_result.getDefiningOp());
  if (failed(original_result_iface.reifyReturnTypeShapes(
          rewriter, original_result_iface->getOperands(), result_shape))) {
    return {};
  }

  // Materialize final reshape.
  Value unshaped_result = unshaped_results.front();
  Value result = rewriter.create<mhlo::DynamicReshapeOp>(
      loc, DeriveUnrankedTensorTypes(unshaped_result.getType()),
      unshaped_result, result_shape.front());

  // Reify shapes until they are independent of operations in the original
  // cluster.
  {
    Operation *it = result_shape.front().getDefiningOp();
    while (it != nullptr && it != last_op_before_shape_reification) {
      bool advanced = false;
      if (auto shape_of_op = llvm::dyn_cast<shape::ShapeOfOp>(it)) {
        Operation *def = shape_of_op.arg().getDefiningOp();
        if (def && def->getBlock() == op.getBody()) {
          // Resolve the `shape_of` op because it still depends on an operation
          // in the original cluster.
          OpBuilder::InsertionGuard guard(rewriter);
          rewriter.setInsertionPoint(shape_of_op);
          SmallVector<Value, 1> tmp_shape;
          auto iface = llvm::cast<InferShapedTypeOpInterface>(def);
          if (failed(iface.reifyReturnTypeShapes(rewriter, iface->getOperands(),
                                                 tmp_shape)))
            return {};
          rewriter.replaceOp(shape_of_op, tmp_shape.front());

          // Continue, including the newly created operations.
          it = tmp_shape.front().getDefiningOp();
          advanced = true;
        }
      }

      // Skip op, otherwise.
      if (!advanced) it = it->getPrevNode();
    }
  }

  // Replace all remaining uses of the original cluster's block args.
  for (auto it : llvm::zip(op.operands(), op.getBody()->getArguments())) {
    Value operand, barg;
    std::tie(operand, barg) = it;
    barg.replaceUsesWithIf(operand, [&](OpOperand &operand) {
      return operand.getOwner()->getBlock() != op.getBody();
    });
  }

  return {result};
}

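// Materializes the flattened shape for a set of shapes that are assumed to be
// equal at runtime: a rank-1 extent tensor holding their total element count.
// Reshaping an operand to this shape flattens it to rank 1.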
Value MaterializeFlatShape(OpBuilder &b, Location loc, ValueRange same_shapes) {
  assert(!same_shapes.empty() && "Expected at least one shape.");
  Value shape = same_shapes.size() == 1
                    ? same_shapes.front()
                    : b.create<shape::AnyOp>(loc, same_shapes.front().getType(),
                                             same_shapes);
  return b.create<tensor::FromElementsOp>(
      loc,
      b.create<shape::NumElementsOp>(loc, b.getIndexType(), shape).result());
}

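// Materializes the rank specialization case in which all operands, except the
// given non-scalars of one common shape, contain exactly one element at
// runtime. In that case the non-scalars are flattened to rank 1, the remaining
// operands are reshaped to rank 0, and the cluster is applied to these ranked
// values.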
Value MaterializeScalarRankSpecializationCase(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes, ValueRange non_scalars_of_same_shape,
    function_ref<void(OpBuilder &, Location)> else_builder_fn) {
  // Materialize predicate: All operands are scalars, except the expected
  // non-scalars.
  Value one = b.create<ConstantIndexOp>(loc, 1);
  Value all_others_are_scalar;
  for (auto it : llvm::zip(op.operands(), shapes)) {
    Value operand, shape;
    std::tie(operand, shape) = it;
    if (llvm::is_contained(non_scalars_of_same_shape, operand) ||
        IsScalarTensorType(operand.getType())) {
      continue;
    }
    auto literal =
        b.create<CmpIOp>(loc, CmpIPredicate::eq,
                         b.create<shape::NumElementsOp>(loc, shape), one);
    all_others_are_scalar =
        all_others_are_scalar
            ? b.create<mlir::AndOp>(loc, all_others_are_scalar, literal)
                  .getResult()
            : literal.result();
  }

  auto if_op = b.create<scf::IfOp>(
      loc, op->getResultTypes(), all_others_are_scalar,
      [&](OpBuilder &b, Location loc) {
        // Compute flat non-scalar shape.
        SmallVector<Value, 4> non_scalar_shapes;
        for (auto it : llvm::zip(op.operands(), shapes)) {
          Value operand, shape;
          std::tie(operand, shape) = it;
          if (llvm::is_contained(non_scalars_of_same_shape, operand))
            non_scalar_shapes.push_back(shape);
        }
        Value flat_shape = MaterializeFlatShape(b, loc, non_scalar_shapes);

        // Derive ranked operands.
        auto ranked_operands =
            llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
              if (IsScalarTensorType(v.getType())) return v;
              if (!llvm::is_contained(non_scalars_of_same_shape, v)) {
                return b
                    .create<mhlo::ReshapeOp>(
                        loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/0),
                        v)
                    .getResult();
              }
              return b
                  .create<mhlo::DynamicReshapeOp>(
                      loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/1), v,
                      flat_shape)
                  .getResult();
            }));

        // Materialize ranked variants for the element-wise operations.
        BlockAndValueMapping bvm;
        for (auto it : llvm::zip(op.getBody()->getArguments(), ranked_operands))
          bvm.map(std::get<0>(it), std::get<1>(it));
        Value unshaped_result =
            MaterializeRankedOperations(b, loc, bvm, op).front();

        // Return as unranked tensor for compatibility with the other cases.
        b.create<scf::YieldOp>(
            loc, b.create<tensor::CastOp>(
                     loc, DeriveUnrankedTensorTypes(unshaped_result.getType()),
                     unshaped_result)
                     .dest());
      },
      else_builder_fn);

  return if_op.results().front();
}

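// Materializes the rank specialization case in which all non-scalar operands
// turn out to have the same shape at runtime. All operands can then be
// flattened to rank 1 (scalar operands are passed through unchanged).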
Value MaterializeEqualShapesRankSpecializationCase(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes,
    function_ref<void(OpBuilder &, Location)> else_builder_fn) {
  // Materialize all shapes equal predicate.
  Value all_shapes_eq_or_scalar;
  auto non_scalar_shapes = llvm::to_vector<8>(llvm::make_filter_range(
      shapes, [](Value v) { return !IsScalarShapeType(v.getType()); }));
  assert(
      non_scalar_shapes.size() >= 2 &&
      "Equal shapes strategy requires at least two non-scalar operand shapes.");
  for (Value s : llvm::drop_begin(non_scalar_shapes)) {
    auto literal =
        b.create<shape::ShapeEqOp>(loc, non_scalar_shapes.front(), s);
    all_shapes_eq_or_scalar =
        all_shapes_eq_or_scalar
            ? b.create<mlir::AndOp>(loc, all_shapes_eq_or_scalar, literal)
                  .result()
            : literal;
  }

  auto if_op = b.create<scf::IfOp>(
      loc, op->getResultTypes(), all_shapes_eq_or_scalar,
      [&](OpBuilder &b, Location loc) {
        // Flatten non-scalar operands.
        Value flat_shape = MaterializeFlatShape(b, loc, non_scalar_shapes);
        auto flat_operands =
            llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
              if (IsScalarTensorType(v.getType())) return v;
              return b
                  .create<mhlo::DynamicReshapeOp>(
                      loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/1), v,
                      flat_shape)
                  .result();
            }));

        // Materialize ranked variants for the element-wise operations.
        BlockAndValueMapping bvm;
        for (auto it : llvm::zip(op.getBody()->getArguments(), flat_operands))
          bvm.map(std::get<0>(it), std::get<1>(it));
        Value unshaped_result =
            MaterializeRankedOperations(b, loc, bvm, op).front();

        // Return as unranked tensor for compatibility with the other cases.
        b.create<scf::YieldOp>(
            loc, b.create<tensor::CastOp>(
                     loc, DeriveUnrankedTensorTypes(unshaped_result.getType()),
                     unshaped_result)
                     .dest());
      },
      else_builder_fn);

  return if_op.results().front();
}

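// Materializes the rank specialization case for one fixed target rank:
// unranked operand shapes are broadcast with an all-ones shape of that rank
// and the operands are then reshaped to dynamically shaped tensors of the
// target rank.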
Value MaterializeTargetRankSpecializationCase(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes, int64_t target_rank) {
  // Reshape unranked operands to match the target rank.
  RankedTensorType extent_tensor_ty =
      shape::getExtentTensorType(b.getContext(), target_rank);
  Value all_ones_shape = b.create<shape::ConstShapeOp>(
      loc, extent_tensor_ty,
      mlir::DenseIntElementsAttr::get(extent_tensor_ty,
                                      SmallVector<int64_t, 6>(target_rank, 1)));
  SmallVector<Value, 8> ranked_operands;
  for (auto it : llvm::zip(op.operands(), shapes)) {
    Value operand, shape;
    std::tie(operand, shape) = it;
    if (operand.getType().isa<RankedTensorType>()) {
      ranked_operands.push_back(operand);
      continue;
    }
    Value ranked_shape = b.create<tensor::CastOp>(
        loc, extent_tensor_ty,
        b.create<shape::BroadcastOp>(loc,
                                     shape::getExtentTensorType(b.getContext()),
                                     shape, all_ones_shape,
                                     /*error=*/nullptr));
    ranked_operands.push_back(b.create<mhlo::DynamicReshapeOp>(
        loc, DeriveRankedTensorTypes(operand.getType(), target_rank), operand,
        ranked_shape));
  }

  // Materialize ranked versions of the element-wise operations.
  BlockAndValueMapping bvm;
  for (auto it : llvm::zip(op.body().front().getArguments(), ranked_operands))
    bvm.map(std::get<0>(it), std::get<1>(it));

  // Return as unranked for compatibility with other target ranks.
  auto unshaped_result = MaterializeRankedOperations(b, loc, bvm, op).front();
  return b.create<tensor::CastOp>(
      loc, DeriveUnrankedTensorTypes(unshaped_result.getType()),
      unshaped_result);
}

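// Recursively materializes a cascade of `scf.if` cases for the target ranks
// `min_target_rank` through `max_target_rank`. The final case is guarded by an
// assert on the maximum supported rank instead of another `scf.if`.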
Value RecusivelyMaterializeTargetRankSpecializationCases(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes, Value max_rank,
    int64_t min_target_rank, int64_t max_target_rank) {
  Value condition =
      b.create<CmpIOp>(loc, CmpIPredicate::ule, max_rank,
                       b.create<ConstantIndexOp>(loc, min_target_rank));

  // If only a unique target rank is left, we can lower to an assert instead
  // of the usual if operation.
  if (min_target_rank == max_target_rank) {
    b.create<AssertOp>(loc, condition,
                       "Input for dynamic binary or n-ary op lowering was of "
                       "a rank greater than " +
                           std::to_string(max_target_rank));
    return MaterializeTargetRankSpecializationCase(b, loc, op, shapes,
                                                   min_target_rank);
  }

  // Materialize IR for the smallest considered target rank.
  auto if_op = b.create<scf::IfOp>(loc, op->getResultTypes(), condition,
                                   /*withElseRegion=*/true);
  auto then_builder = if_op.getThenBodyBuilder();
  then_builder.create<scf::YieldOp>(
      loc, MaterializeTargetRankSpecializationCase(then_builder, loc, op,
                                                   shapes, min_target_rank));

  // Recurse for all remaining target ranks.
  auto else_builder = if_op.getElseBodyBuilder();
  else_builder.create<scf::YieldOp>(
      loc, RecusivelyMaterializeTargetRankSpecializationCases(
               else_builder, loc, op, shapes, max_rank, min_target_rank + 1,
               max_target_rank));

  return if_op.results().front();
}

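// Materializes the generic cases: the operand shapes are first reduced to
// their minimum broadcast shapes, the maximum rank among the reduced shapes is
// computed at runtime, and one case per target rank up to `max_target_rank` is
// emitted.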
Value MaterializeGenericRankSpecializationCases(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes, int64_t max_target_rank) {
  // Get the minimum broadcast shapes of the operands.
  auto non_scalar_shapes = llvm::to_vector<8>(llvm::make_filter_range(
      shapes, [](Value v) { return !IsScalarShapeType(v.getType()); }));
  auto min_bcast_shapes_op = b.create<chlo::MinimumBroadcastShapesOp>(
      loc,
      SmallVector<Type, 8>(non_scalar_shapes.size(),
                           shape::getExtentTensorType(b.getContext())),
      non_scalar_shapes);

  // Find the maximum rank among the reduced operand shapes.
  Value max_rank;
  for (Value shape : min_bcast_shapes_op.results()) {
    Value rank = b.create<shape::RankOp>(loc, b.getIndexType(), shape);
    if (!max_rank) {
      max_rank = rank;
    } else {
      max_rank = b.create<mlir::SelectOp>(
          loc, b.create<CmpIOp>(loc, CmpIPredicate::sgt, max_rank, rank),
          max_rank, rank);
    }
  }

  // Collect reduced shapes.
  SmallVector<Value, 8> reduced_shapes;
  auto it = min_bcast_shapes_op.result_begin();
  for (Value s : shapes) {
    if (IsScalarShapeType(s.getType())) {
      reduced_shapes.push_back(s);
    } else {
      reduced_shapes.push_back(*it++);
    }
  }

  // Materialize rank specialization for ranks 1, ...
  return RecusivelyMaterializeTargetRankSpecializationCases(
      b, loc, op, reduced_shapes, max_rank, /*min_target_rank=*/1,
      max_target_rank);
}

Value MaterializeDefaultRankSpecializationCases(
    OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
    const SmallVector<Value, 8> &shapes, int64_t max_target_rank) {
  return MaterializeEqualShapesRankSpecializationCase(
      b, loc, op, shapes, [&](OpBuilder &b, Location loc) {
        b.create<scf::YieldOp>(loc, MaterializeGenericRankSpecializationCases(
                                        b, loc, op, shapes, max_target_rank));
      });
}

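// Materializes rank specialization when all non-scalar operands belong to one
// shape equivalence class. No runtime case distinction is needed: all
// non-scalar operands are flattened to rank 1 and the results are reshaped
// back to the common shape.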
SmallVector<Value, 8>
MaterializeRankSpecializationForSingleNonScalarShapeEquivalenceClass(
    PatternRewriter &rewriter, Location loc,
    chlo::RankSpecializationClusterOp op,
    ValueRange non_scalars_of_same_shape) {
  // Compute flat operand shape.
  auto non_scalar_shapes = llvm::to_vector<4>(
      llvm::map_range(non_scalars_of_same_shape, [&](Value v) {
        return rewriter.create<shape::ShapeOfOp>(loc, v).result();
      }));
  Value flat_shape = MaterializeFlatShape(rewriter, loc, non_scalar_shapes);

  // Materialize ranked variants for the element-wise operations.
  BlockAndValueMapping bvm;
  for (auto it : llvm::zip(op.getBody()->getArguments(), op.operands())) {
    Value operand;
    Value bb_arg;
    std::tie(bb_arg, operand) = it;
    if (!IsScalarTensorType(operand.getType())) {
      assert(llvm::is_contained(non_scalars_of_same_shape, operand) &&
             "Expected all non-scalars in the same shape equivalence class.");
      operand = rewriter.create<mhlo::DynamicReshapeOp>(
          loc, DeriveRankedTensorTypes(operand.getType(), /*rank=*/1), operand,
          flat_shape);
    }
    bvm.map(bb_arg, operand);
  }
  SmallVector<Value, 8> unshaped_results =
      MaterializeRankedOperations(rewriter, loc, bvm, op);

  // Restore the results' expected shape.
  Value shape = non_scalar_shapes.front();
  return llvm::to_vector<8>(llvm::map_range(unshaped_results, [&](Value v) {
    return rewriter
        .create<mhlo::DynamicReshapeOp>(
            loc, DeriveUnrankedTensorTypes(v.getType()), v, shape)
        .result();
  }));
}

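// Materializes rank specialization for exactly two non-scalar shape
// equivalence classes: first, the two cases in which either class turns out to
// be all-scalars at runtime (which again allow full flattening), then the
// default cascade of cases.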
Value MaterializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses(
    PatternRewriter &rewriter, Location loc,
    chlo::RankSpecializationClusterOp op,
    SmallVector<SmallVector<Value, 4>, 4> non_scalar_eqs,
    int64_t max_target_rank) {
  assert(non_scalar_eqs.size() == 2 &&
         "Expect two non-scalar equivalence classes.");
  auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
    return rewriter.create<shape::ShapeOfOp>(loc, v).result();
  }));
  ValueRange lhs_non_scalar_eqs = non_scalar_eqs[0];
  ValueRange rhs_non_scalar_eqs = non_scalar_eqs[1];

  // Materialize all the different cases.
  Value unshaped_result = MaterializeScalarRankSpecializationCase(
      rewriter, loc, op, shapes, rhs_non_scalar_eqs,
      [&](OpBuilder &b, Location loc) {
        b.create<scf::YieldOp>(
            loc, MaterializeScalarRankSpecializationCase(
                     b, loc, op, shapes, lhs_non_scalar_eqs,
                     [&](OpBuilder &b, Location loc) {
                       b.create<scf::YieldOp>(
                           loc, MaterializeDefaultRankSpecializationCases(
                                    b, loc, op, shapes, max_target_rank));
                     }));
      });

  // Materialize final reshape once and for all rank specialization cases.
  return MaterializeFinalReshape(rewriter, loc, op, unshaped_result).front();
}

// Materialize the generic rank specialization.
Value MaterializeDefaultRankSpecialization(PatternRewriter &rewriter,
                                           Location loc,
                                           chlo::RankSpecializationClusterOp op,
                                           int64_t max_target_rank) {
  auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
    return rewriter.create<shape::ShapeOfOp>(loc, v).result();
  }));

  // Materialize all the different cases.
  Value unshaped_result = MaterializeDefaultRankSpecializationCases(
      rewriter, loc, op, shapes, max_target_rank);

  // Materialize final reshape once and for all rank specialization cases.
  return MaterializeFinalReshape(rewriter, loc, op, unshaped_result).front();
}

// This is a very limited form of shape inference. It is correct but incomplete.
SmallVector<SmallVector<Value, 4>, 4> FindNonScalarShapeEquivalences(
    chlo::RankSpecializationClusterOp op) {
  llvm::EquivalenceClasses<Value> eqs;

  // Bridge the equivalences between operands and block arguments.
  for (auto it : llvm::zip(op.operands(), op.getBody()->getArguments()))
    eqs.unionSets(std::get<0>(it), std::get<1>(it));

  // Find equalities through `SameOperandsAndResultShape` trait.
  auto union_sets = [&](ValueRange vs) {
    if (vs.empty()) return;
    Value repr = vs.front();
    for (Value v : vs.drop_front()) eqs.unionSets(repr, v);
  };
  for (Operation &nested_op : op.getBody()->without_terminator()) {
    if (nested_op.hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) {
      union_sets(nested_op.getOperands());
      union_sets(nested_op.getResults());
      if (!nested_op.getOperands().empty() && !nested_op.getResults().empty())
        eqs.unionSets(nested_op.getResult(0), nested_op.getOperand(0));
    }
  }

  // Find shape equalities through surrounding constraints.
  if (auto assuming_op = op->getParentOfType<shape::AssumingOp>()) {
    SmallVector<Operation *, 8> queue;
    auto append_if_not_null = [&](Operation *op) {
      if (op != nullptr) queue.push_back(op);
    };
    append_if_not_null(assuming_op.witness().getDefiningOp());
    while (!queue.empty()) {
      Operation *it = queue.pop_back_val();
      if (auto assuming_all_op = llvm::dyn_cast<shape::AssumingAllOp>(it)) {
        for (Value v : assuming_all_op.inputs())
          append_if_not_null(v.getDefiningOp());
      } else if (auto cstr_eq_op = llvm::dyn_cast<shape::CstrEqOp>(it)) {
        Value ref_arg;
        for (Value v : cstr_eq_op.shapes()) {
          if (auto shape_of_op =
                  dyn_cast_or_null<shape::ShapeOfOp>(v.getDefiningOp())) {
            if (!ref_arg) {
              ref_arg = shape_of_op.arg();
            } else {
              eqs.unionSets(ref_arg, shape_of_op.arg());
            }
          }
        }
      }
    }
  }

  // Find equalities through special knowledge of ops.
  // TODO(frgossen): Remove this when these shape equalities can be inferred
  // from surrounding shape constraints.
  for (Operation &nested_op : op.getBody()->without_terminator()) {
    if (auto select_op = llvm::dyn_cast<mhlo::SelectOp>(nested_op)) {
      union_sets(
          {select_op.on_true(), select_op.on_false(), select_op.getResult()});
    } else if (auto clamp_op = llvm::dyn_cast<mhlo::ClampOp>(nested_op)) {
      union_sets({clamp_op.operand(), clamp_op.getResult()});
    }
  }

  // Convert to a list-like equivalence class representation.
  SmallVector<SmallVector<Value, 4>, 4> non_scalar_eqs;
  for (Value v : op.operands()) {
    if (IsScalarTensorType(v.getType())) continue;
    bool inserted = false;
    for (auto &eq_class : non_scalar_eqs) {
      if (eqs.isEquivalent(eq_class.front(), v)) {
        eq_class.push_back(v);
        inserted = true;
        break;
      }
    }
    if (!inserted) non_scalar_eqs.push_back(SmallVector<Value, 4>({v}));
  }

  return non_scalar_eqs;
}

struct LowerRankSpecializationClusterPattern
    : public OpRewritePattern<chlo::RankSpecializationClusterOp> {
  LowerRankSpecializationClusterPattern(MLIRContext *ctx,
                                        int64_t max_target_rank)
      : OpRewritePattern<chlo::RankSpecializationClusterOp>(ctx, /*benefit=*/1),
        max_target_rank(max_target_rank) {}

  LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
                                PatternRewriter &rewriter) const override {
    // Restoring the result shape currently relies on all operands being used
    // for a single result. The result shape is then the broadcasted shape of
    // all operands.
    if (op.getNumResults() != 1) return failure();

    // If there is only a single non-scalar shape equivalence class, we can
    // flatten the operands completely.
    SmallVector<SmallVector<Value, 4>, 4> non_scalar_eqs =
        FindNonScalarShapeEquivalences(op);
    Location loc = op.getLoc();
    if (non_scalar_eqs.size() == 1) {
      rewriter.replaceOp(
          op,
          MaterializeRankSpecializationForSingleNonScalarShapeEquivalenceClass(
              rewriter, loc, op, non_scalar_eqs.front()));
      return success();
    }

    // If there are exactly two non-scalar shape equivalence classes, we can
    // consider two extra cases: If either of the operand classes turns out to
    // be all-scalars at runtime, we can, again, flatten all operands.
    if (non_scalar_eqs.size() == 2) {
      rewriter.replaceOp(
          op,
          MaterializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses(
              rewriter, loc, op, non_scalar_eqs, max_target_rank));
      return success();
    }

    // For all other cases, reshape the operands to match in rank, apply the
    // operation, and restore the expected shape.
    rewriter.replaceOp(op, MaterializeDefaultRankSpecialization(
                               rewriter, loc, op, max_target_rank));
    return success();
  }

 private:
  int64_t max_target_rank;
};

struct RankSpecializationToSCFPass
    : public RankSpecializationToSCFPassBase<RankSpecializationToSCFPass> {
  explicit RankSpecializationToSCFPass(int64_t max_target_rank)
      : RankSpecializationToSCFPassBase<
            RankSpecializationToSCFPass>::RankSpecializationToSCFPassBase() {
    this->max_target_rank_ = max_target_rank;
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mhlo::MhloDialect, chlo::HloClientDialect,
                    shape::ShapeDialect, scf::SCFDialect>();
  }

  void runOnFunction() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    PopulateRankSpecializationToSCFPatterns(ctx, &patterns,
                                            this->max_target_rank_);
    if (failed(
            applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};

}  // namespace

void PopulateRankSpecializationClusterPatterns(
    MLIRContext *context, OwningRewritePatternList *patterns) {
  patterns->insert<MergeRankSpecializationClusterOpsPattern,
                   RankSpecializationClusterPattern>(context);
}

void PopulateRankSpecializationToSCFPatterns(MLIRContext *context,
                                             OwningRewritePatternList *patterns,
                                             int64_t max_target_rank) {
  patterns->insert<LowerRankSpecializationClusterPattern>(context,
                                                          max_target_rank);
  shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context);
  shape::ShapeOfOp::getCanonicalizationPatterns(*patterns, context);
  shape::AnyOp::getCanonicalizationPatterns(*patterns, context);
}

std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() {
  return std::make_unique<RankSpecializationClusterPass>();
}

std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass(
    int64_t max_target_rank) {
  return std::make_unique<RankSpecializationToSCFPass>(max_target_rank);
}

}  // namespace mhlo
}  // namespace mlir