/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/Support/Debug.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/utils/cycle_detector.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"              // TF:llvm-project
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"               // TF:local_config_mlir
#include "mlir/Transforms/RegionUtils.h"  // TF:llvm-project
// This pass provides functionality similar to the fusion pass in the XLA
// stack. However, unlike XLA, it targets the fully dynamic shape scenario.
// Currently, it implements the kLoop and kInput fusion templates.
// During conversion, it greedily searches for kLoop/kInput fusion
// patterns.
//
// Similar to XLA, this pass supports fusion patterns that have multiple
// outputs, as long as all output shapes are consistent. Some examples:
//
//        kLoop                          kInput
// +----+  +----+  +----+    +----+    +----+    +----+
// |elem|  |elem|  |elem|    |elem<----+elem+---->elem+----+
// +-+--+  +-+--+  +-+--+    +-+--+    +----+    +-+--+    |
//   |       |       |         |                   |       |
//   |               |         |                   |       |
// +-v--+    |     +-v--+   +--v---+            +--v---+   |
// |elem+<---+----<+elem|   |reduce|            |reduce|   |
// +-+--+          +-+--+   +--+---+            +--+---+   |
//   |               |         |                   |       |
//   |               |         |                   |       |
//   v               v         v                   v       v
//
// To this end, we also add a simple shape constraint analysis phase.
// The kLoop fusion template requires all outputs of a fused pattern to
// have the same shape. However, in the dynamic shape world, we do not know
// the actual shape values at compile time. Fortunately, we can still infer
// the relationship among different ops according to their shape constraint
// traits. Currently, we only consider shape equality propagation for
// elementwise ops (assuming that implicit shape broadcast is forbidden).
// The above process could be built on the shape dialect once it is ready.
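//
// As a sketch of the end result (illustrative only; the exact textual IR
// syntax depends on the MLIR version in use), the pass rewrites a chain of
// shape-compatible elementwise ops such as
//
//   %0 = "mhlo.add"(%arg0, %arg1)
//       : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
//   %1 = "mhlo.subtract"(%0, %arg2)
//       : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
//
// into a single mhlo.fusion op whose region holds the fused ops and
// terminates with an mhlo.return of the pattern's outputs:
//
//   %fused = "mhlo.fusion"(%arg0, %arg1, %arg2) ({
//     %0 = "mhlo.add"(%arg0, %arg1)
//         : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
//     %1 = "mhlo.subtract"(%0, %arg2)
//         : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
//     "mhlo.return"(%1) : (tensor<?xf32>) -> ()
//   }) : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>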
63
64 namespace mlir {
65 namespace mhlo {
66 namespace {
67
68 using llvm::EquivalenceClasses;
69 using FusionPattern = std::vector<Operation*>;
70 using FusionPlan = std::vector<FusionPattern>;
71
72 // To support using EquivalenceClasses for Value
73 class ValueWrapper {
74 public:
ValueWrapper(Value value)75 explicit ValueWrapper(Value value) : value_(std::move(value)) {}
76
getValue() const77 Value getValue() const { return value_; }
78
operator ==(const ValueWrapper & rhs) const79 bool operator==(const ValueWrapper& rhs) const {
80 return getValue() == rhs.getValue();
81 }
82
83 private:
84 Value value_;
85 };
86
operator <(const ValueWrapper & lhs,const ValueWrapper & rhs)87 bool operator<(const ValueWrapper& lhs, const ValueWrapper& rhs) {
88 auto lhs_value = lhs.getValue().getAsOpaquePointer();
89 auto rhs_value = rhs.getValue().getAsOpaquePointer();
90 return lhs_value < rhs_value;
91 }
92
IsFusible(Operation * op)93 bool IsFusible(Operation* op) {
94 if (matchPattern(op, m_Constant())) {
95 return true;
96 }
97 auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
98 return op_fusibility && (op_fusibility.isFusibleWithOperand() ||
99 op_fusibility.isFusibleWithConsumer());
100 }
101
GetInputsOfFusionPattern(const FusionPattern & pattern)102 SmallVector<Value, 4> GetInputsOfFusionPattern(const FusionPattern& pattern) {
103 SmallVector<Value, 4> inputs;
104 DenseSet<Value> input_set;
105 DenseSet<Operation*> op_set;
106 for (Operation* op : pattern) {
107 bool inserted = op_set.insert(op).second;
108 (void)inserted;
109 assert(inserted && "FusionPattern contains duplicate operations");
110 }
111
112 for (Operation* op : pattern) {
113 for (Value operand : op->getOperands()) {
114 Operation* operand_op = operand.getDefiningOp();
115 if (op_set.find(operand_op) != op_set.end()) {
116 // skip if defining op is in the pattern
117 continue;
118 }
119 if (input_set.insert(operand).second) {
120 inputs.push_back(operand);
121 }
122 }
123 }
124 return inputs;
125 }
126
GetOutputsOfFusionPattern(const FusionPattern & pattern)127 SmallVector<Value, 4> GetOutputsOfFusionPattern(const FusionPattern& pattern) {
128 SmallVector<Value, 4> outputs;
129 DenseSet<Operation*> op_set;
130 for (Operation* op : pattern) {
131 bool inserted = op_set.insert(op).second;
132 (void)inserted;
133 assert(inserted && "FusionPattern contains duplicate operations");
134 }
135
136 for (Operation* op : pattern) {
137 for (Value result : op->getResults()) {
138 bool has_external_user = llvm::any_of(
139 result.getUses(),
140 [&](OpOperand& use) { return !op_set.count(use.getOwner()); });
141 if (has_external_user) {
142 outputs.push_back(result);
143 }
144 }
145 }
146 return outputs;
147 }
148
MergeFusionPattern(const FusionPattern & lhs,const FusionPattern & rhs)149 FusionPattern MergeFusionPattern(const FusionPattern& lhs,
150 const FusionPattern& rhs) {
151 FusionPattern pattern(lhs);
152 pattern.insert(pattern.end(), rhs.begin(), rhs.end());
153 return pattern;
154 }
155
EffectiveSize(const FusionPattern & pattern)156 inline int EffectiveSize(const FusionPattern& pattern) {
157 return llvm::count_if(
158 pattern, [](Operation* op) { return !matchPattern(op, m_Constant()); });
159 }
160
// This is a simple shape constraint analysis, which is used to
// guide fusion decisions (e.g. we only fuse shape-compatible ops).
//
// Currently, we only consider shape equality propagation based
// on the shape constraint traits of elementwise ops (assuming that
// implicit shape broadcast is forbidden).
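//
// For illustration (a hypothetical case; value names chosen only for this
// example): given %2 = mhlo.add(%0, %1), where the op declares input/output
// shape equality via InferFusibilityOpInterface, the analysis places %0, %1
// and %2 in one equivalence class, so HasSameShape(%0, %2) returns true even
// though the concrete shapes are unknown at compile time.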
class ShapeConstraintAnalysis {
 public:
  explicit ShapeConstraintAnalysis(const SmallVectorImpl<Operation*>& op_list) {
    PropagateEquality(op_list);
  }

  // Returns true if `lhs` and `rhs` are supposed to have the same shape.
  bool HasSameShape(Value lhs, Value rhs) {
    return impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs));
  }

 private:
  // Shape equality propagation based on the shape constraints of
  // elementwise ops.
  void PropagateEquality(const SmallVectorImpl<Operation*>& op_list) {
    bool converged = true;
    do {
      converged = true;
      auto update = [&](Value lhs, Value rhs) {
        if (!impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs))) {
          converged = false;
          impl_.unionSets(ValueWrapper(lhs), ValueWrapper(rhs));
        }
      };
      for (Operation* op : op_list) {
        auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
        if (!op_fusibility) continue;
        int numInput = op->getNumOperands();
        int numOutput = op->getNumResults();
        // shape equality propagation between inputs.
        for (int input1 = 0; input1 < numInput; ++input1)
          for (int input2 = input1 + 1; input2 < numInput; ++input2)
            if (op_fusibility.inferInputsShapeEquality(input1, input2))
              update(op->getOperand(input1), op->getOperand(input2));

        // shape equality propagation between outputs.
        for (int output1 = 0; output1 < numOutput; ++output1)
          for (int output2 = output1 + 1; output2 < numOutput; ++output2)
            if (op_fusibility.inferOutputsShapeEquality(output1, output2))
              update(op->getResult(output1), op->getResult(output2));

        // shape equality propagation between input and output.
        for (int input = 0; input < numInput; ++input)
          for (int output = 0; output < numOutput; ++output)
            if (op_fusibility.inferInputOutputShapeEquality(input, output))
              update(op->getOperand(input), op->getResult(output));
      }
    } while (!converged);
  }

  // a UnionFind set
  EquivalenceClasses<ValueWrapper> impl_;
};
// A fusion planner that can propose a fusion plan for a block of ops.
// The fusion plan consists of a group of fusion patterns.
//
// Currently all proposed patterns follow XLA's kLoop/kInput fusion
// templates, but are adapted to the fully dynamic shape world.
//
// The kLoop fusion template requires:
//   - all ops in the fusion pattern are elementwise.
//   - all outputs of the fusion pattern have the same shape, and thus
//     fit into the same parallel loop.
//
// The kInput fusion template requires:
//   - every op in the fusion pattern is either elementwise or a reduction.
//   - if an op is a reduction, its output cannot be consumed by other
//     ops in the same fusion pattern.
//   - all outputs of the fusion pattern have the same effective shape.
//     - For an elementwise op, its effective shape is its output shape.
//     - For a reduction op, its effective shape is its operand shape.
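//
// For example (a hypothetical pattern): an elementwise mhlo.exponential on
// tensor<?x?xf32> followed by an mhlo.reduce that collapses it to
// tensor<?xf32> can form one kInput pattern, because the reduction's
// effective shape is its operand shape tensor<?x?xf32>, which matches the
// elementwise op's output shape.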
class FusionPlanner {
 public:
  explicit FusionPlanner(const SmallVectorImpl<Operation*>& op_list)
      : op_list_(op_list),
        shape_analysis_(op_list),
        cycle_detector_(op_list.size()) {
    BuildNodeMap();
  }

  // Returns a fusion plan on success, otherwise none.
  llvm::Optional<FusionPlan> Run() {
    // Greedily search for connected fusible patterns; ops belonging to
    // the same fusion pattern are grouped into a cluster.
    RunEdgeContractionLoop();

    // After edge contraction, each unique cluster with a size greater
    // than one represents a potential fusion pattern.
    // We collect all these clusters and construct a fusion plan.
    //
    // Note that the ops in a fusion pattern are in topological order.
    FusionPlan plan;
    DenseMap<int, int> pattern_ids;
    for (Operation* op : op_list_) {
      Cluster* cluster = GetClusterForNode(op);
      int node_id = cluster->cycles_graph_node_id();
      if (!IsFusible(op_list_[node_id]) ||
          EffectiveSize(GetClusterForNode(op)->fused_pattern()) <= 1) {
        continue;
      }
      if (!pattern_ids.count(node_id)) {
        int pattern_id = pattern_ids.size();
        pattern_ids[node_id] = pattern_id;
        plan.emplace_back();
      }
      plan[pattern_ids[node_id]].push_back(op);
    }
    return plan;
  }

  // Returns the op_list this planner operates on.
  const SmallVectorImpl<Operation*>& op_list() const { return op_list_; }
 private:
  // Represents a (partial) fused pattern.
  class Cluster {
   public:
    Cluster(int node_id, FusionPlanner* planner) : node_id_(node_id) {
      const SmallVectorImpl<Operation*>& op_list = planner->op_list();
      pattern_.push_back(op_list[node_id]);
    }

    // Merges `other` into this cluster, and clears `other`.
    void Merge(Cluster* other) {
      pattern_.insert(pattern_.end(), other->pattern_.begin(),
                      other->pattern_.end());
      other->pattern_.clear();
    }

    // The number of nodes in this cluster.
    int cluster_size() const { return pattern_.size(); }

    // The ID of the cluster as represented in `cycle_detector_`.
    int cycles_graph_node_id() const { return node_id_; }

    // Sets the ID of the cluster as represented in `cycle_detector_`.
    void set_cycles_graph_node_id(int cycles_graph_node_id) {
      node_id_ = cycles_graph_node_id;
    }

    // The fused pattern this cluster currently holds.
    const FusionPattern& fused_pattern() { return pattern_; }

   private:
    // ID of the representative node of this cluster.
    int node_id_;

    // The fused pattern this cluster holds.
    FusionPattern pattern_;
  };
 private:
  Cluster* MakeCluster(int cycles_graph_node_id) {
    cluster_storage_.emplace_back(new Cluster(cycles_graph_node_id, this));
    return cluster_storage_.back().get();
  }

  void BuildNodeMap() {
    int num_nodes = op_list_.size();
    for (int node_id = 0; node_id < num_nodes; ++node_id) {
      Operation* op = op_list_[node_id];
      MakeCluster(node_id);
      op_to_node_id_[op] = node_id;
      leader_for_node_.insert(node_id);
      for (Value operand : op->getOperands()) {
        Operation* operand_op = operand.getDefiningOp();
        if (operand_op == nullptr) {
          // skip block argument
          continue;
        }
        auto iter = op_to_node_id_.find(operand_op);
        assert(iter != op_to_node_id_.end());
        cycle_detector_.InsertEdge(iter->second, node_id);
      }
    }
  }

  // Returns the cluster that contains this op.
  Cluster* GetClusterForNode(Operation* n) {
    int id = op_to_node_id_.at(n);
    id = leader_for_node_.getLeaderValue(id);
    return cluster_storage_[id].get();
  }

  // Returns the cluster that contains the op with `node_id`.
  Cluster* GetClusterForCyclesGraphNode(int node_id) {
    return cluster_storage_[leader_for_node_.getLeaderValue(node_id)].get();
  }
  // Merges the clusters `cluster_from` and `cluster_to`.
  bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
    int from = cluster_from->cycles_graph_node_id();
    int to = cluster_to->cycles_graph_node_id();

    auto optional_merged_node = cycle_detector_.ContractEdge(from, to);
    if (!optional_merged_node.hasValue()) {
      llvm::dbgs() << "Could not contract " << from << " -> " << to
                   << " because contracting the edge would create a cycle.";
      return false;
    }

    // Merge the clusters.
    cluster_from->Merge(cluster_to);
    cluster_from->set_cycles_graph_node_id(*optional_merged_node);

    // Merge the UnionFind sets.
    leader_for_node_.unionSets(from, to);
    return true;
  }

  template <typename FnTy>
  bool ForEachEdgeInPostOrder(FnTy fn) {
    bool changed = false;
    for (int32_t node : cycle_detector_.AllNodesInPostOrder()) {
      Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
      // Make a copy of the set of successors because we may modify the graph
      // in TryToContractEdge.
      std::vector<int32_t> successors_copy =
          cycle_detector_.SuccessorsCopy(cluster_from->cycles_graph_node_id());

      for (int to : successors_copy) {
        Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
        bool contracted_edge = fn(cluster_from, cluster_to);
        changed |= contracted_edge;
      }
    }

    return changed;
  }

  // Returns the outputs the fused pattern would have if the two clusters
  // were merged.
  SmallVector<Value, 4> GetResultsOfFusedPattern(Cluster* from, Cluster* to) {
    FusionPattern fused_pattern =
        MergeFusionPattern(from->fused_pattern(), to->fused_pattern());
    return GetOutputsOfFusionPattern(fused_pattern);
  }

  // This function checks whether fusing `from` with `to` is valid and, if so,
  // performs the merge. The validity is based on the operations in the
  // clusters and the compatibility of the shapes of the outputs of the
  // would-be fused clusters.
  // Returns true if the merge was performed.
  bool TryToContractEdge(Cluster* from, Cluster* to) {
    int node_to = to->cycles_graph_node_id();
    int node_from = from->cycles_graph_node_id();

    // Both node_to and node_from should be fusible.
    if (!IsFusible(op_list_[node_to]) || !IsFusible(op_list_[node_from])) {
      return false;
    }

    auto op_from_fusibility =
        dyn_cast<InferFusibilityOpInterface>(op_list_[node_from]);
    if (op_from_fusibility && !op_from_fusibility.isFusibleWithConsumer()) {
      // This op cannot be fused with its consumers.
      return false;
    }

    auto op_to_fusibility =
        dyn_cast<InferFusibilityOpInterface>(op_list_[node_to]);
    if (op_to_fusibility && !op_to_fusibility.isFusibleWithOperand()) {
      // This op cannot be fused with its operands.
      return false;
    }

    // Output shapes of a fusion pattern should be compatible as described in
    // the documentation of this class.
    SmallVector<Value, 4> results = GetResultsOfFusedPattern(from, to);
    auto get_workload_shape = [](Value v) {
      Operation* op = v.getDefiningOp();
      // Block argument
      if (!op) return v;
      auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
      // Const value
      if (!op_fusibility) return v;
      llvm::Optional<Value> workload =
          op_fusibility.inferEffectiveWorkloadShape();
      return workload.hasValue() ? *workload : v;
    };

    Value ref = get_workload_shape(results[0]);
    if (!llvm::all_of(results, [&](Value result) {
          Value val = get_workload_shape(result);
          return shape_analysis_.HasSameShape(ref, val);
        })) {
      return false;
    }

    return MergeClusters(from, to);
  }

  // Greedily fuses connected nodes.
  bool RunEdgeContractionLoop() {
    using std::placeholders::_1;
    using std::placeholders::_2;
    return ForEachEdgeInPostOrder(
        std::bind(&FusionPlanner::TryToContractEdge, this, _1, _2));
  }

  const SmallVectorImpl<Operation*>& op_list_;

  // Shape equality checker.
  ShapeConstraintAnalysis shape_analysis_;

  // op -> node_id
  std::unordered_map<Operation*, int> op_to_node_id_;

  // Makes sure we do not introduce a cycle after fusion.
  GraphCycles cycle_detector_;
  std::vector<std::unique_ptr<Cluster>> cluster_storage_;

  // A UnionFind set. Each set represents a (partial) fused pattern
  // and has a leader as its representative.
  EquivalenceClasses<int32_t> leader_for_node_;
};

struct MhloFusionPass : public MhloFusionPassBase<MhloFusionPass> {
  void runOnFunction() override {
    FuncOp func = getFunction();
    if (!IsTargetFunc(func)) {
      return;
    }

    // Process each block and do fusion within a block.
    for (Block& block : func) {
      SmallVector<Operation*, 4> op_list;
      for (Operation& op : block) {
        op_list.push_back(&op);
      }

      FusionPlanner planner(op_list);
      llvm::Optional<FusionPlan> plan = planner.Run();
      if (!plan) {
        emitError(func.getLoc(), "can't find a fusion plan");
        signalPassFailure();
        return;
      }
      if (!ApplyFusionPlan(*plan)) {
        emitError(func.getLoc(), "apply fusion plan failed");
        signalPassFailure();
        return;
      }
    }
  }

  bool IsTargetFunc(FuncOp func) {
    int num_fusible_ops = 0;
    bool is_target_func = false;
    // We only process functions that have enough fusion candidates.
    func.walk([&](Operation* op) {
      num_fusible_ops +=
          static_cast<int>(dyn_cast<InferFusibilityOpInterface>(op) != nullptr);
      is_target_func = (num_fusible_ops > 1);
      // Early stop.
      if (is_target_func) return WalkResult::interrupt();
      return WalkResult::advance();
    });
    return is_target_func;
  }

  bool ApplyFusionPlan(const FusionPlan& plan) {
    for (const FusionPattern& pattern : plan) {
      OpBuilder b(pattern.back());

      SmallVector<Location, 4> locations;
      locations.reserve(pattern.size());
      for (Operation* op : pattern) {
        locations.push_back(op->getLoc());
      }
      Location fused_loc =
          FusedLoc::get(pattern.back()->getContext(), locations);

      SmallVector<Value, 4> inputs = GetInputsOfFusionPattern(pattern);
      SmallVector<Value, 4> outputs = GetOutputsOfFusionPattern(pattern);
      SmallVector<Type, 4> output_types;
      output_types.reserve(outputs.size());
      for (Value v : outputs) {
        output_types.push_back(v.getType());
      }

      FusionOp fusion =
          b.create<mhlo::FusionOp>(fused_loc, output_types, inputs);
      Region& region = fusion.fused_computation();
      region.push_back(new Block);
      Block& block = region.front();
      for (Operation* op : pattern) {
        op->moveBefore(&block, block.end());
      }
      b.setInsertionPoint(&block, block.end());
      b.create<mhlo::ReturnOp>(fused_loc, outputs);

      for (auto output_and_result : llvm::zip(outputs, fusion.getResults())) {
        Value output = std::get<0>(output_and_result);
        Value fusion_result = std::get<1>(output_and_result);
        for (OpOperand& use : llvm::make_early_inc_range(output.getUses())) {
          if (use.getOwner()->getBlock() != &block) use.set(fusion_result);
        }
      }
    }
    return true;
  }
};

}  // namespace

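// A sketch of how this pass might be scheduled from C++ (illustrative only;
// the exact pass-manager API depends on the MLIR version in use):
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<FuncOp>(mhlo::createMhloFusionPass());
//   if (failed(pm.run(module))) { /* handle the error */ }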
std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass() {
  return std::make_unique<MhloFusionPass>();
}

}  // namespace mhlo
}  // namespace mlir