/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_FUSION_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_FUSION_UTILS_H_

#include <memory>
#include <vector>

#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/Support/Debug.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project

// This file implements helper functions and classes used for fusion and code
// generation.

namespace mlir {
namespace lmhlo {

// The kLoop fusion template requires:
//   - all ops in the fusion pattern are element-wise.
//   - all output shapes of the fusion pattern are the same or have the same
//     number of elements, and thus can fit into the same parallel loop.
//
// The kInput fusion template requires:
//   - every op in the fusion pattern is either element-wise or a reduction.
//   - if an op is a reduction, its output cannot be consumed by other
//     ops in the same fusion pattern.
//   - all effective output shapes of the fusion pattern are the same.
//     - For an element-wise op, its effective shape is its output shape.
//     - For a reduction op, its effective shape is its operand shape.
//   - currently our downstream codegen engine only supports 2D -> 1D tensor
//     reductions (see the illustrative sketch after the enum below).
//     TODO: lift this limitation.
//     - 2D row reduction: out[i] = sum({in[i][j] for all j})
//     - 2D column reduction: out[j] = sum({in[i][j] for all i})
enum FusionType {
  // Not a fusion pattern
  kNone,
  // kLoop fusion pattern
  kLoop,
  // kInput fusion pattern where all reduce ops of the fused pattern are row
  // reductions
  kRowReduction,
  // kInput fusion pattern where all reduce ops of the fused pattern are column
  // reductions
  kColReduction,
};
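
// An illustrative sketch (not part of the API) of the two supported 2D
// reduction flavors written as plain loops, assuming an `m x n` input `in`
// and an output `out` initialized to the reduction's init value:
//
//   // 2D row reduction (kRowReduction): reduces the innermost dimension.
//   for (int i = 0; i < m; ++i)
//     for (int j = 0; j < n; ++j) out[i] += in[i][j];
//
//   // 2D column reduction (kColReduction): reduces the outermost dimension.
//   for (int i = 0; i < m; ++i)
//     for (int j = 0; j < n; ++j) out[j] += in[i][j];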

// Returns true if the op is an elementwise unary lmhlo op.
// TODO: use fusibility interface
bool isElementWiseUnary(Operation* op);

// Returns true if the op is an elementwise binary lmhlo op.
// TODO: use fusibility interface
bool isElementWiseBinary(Operation* op);

// Returns true if the op is an elementwise lmhlo op.
// TODO: use fusibility interface
bool isElementWise(Operation* op);

// Returns true if this op is a rank-2 row reduction.
bool isRank2RowReduction(Operation* op);

// Returns true if this op is a rank-2 column reduction.
bool isRank2ColReduction(Operation* op);

// Returns true if the op is supported by the downstream fusion codegen
// engine.
bool isFusible(Operation* op);

// Returns the number of operands that are supposed to be written.
// For some ops (e.g. lmhlo ops), some operands are the output memrefs,
// and thus these operands are supposed to be updated.
int getNumResultOperands(Operation* op);
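
// For example (illustrative only, assuming the usual lmhlo convention that
// output buffers trail the input operands):
//
//   "lmhlo.add"(%lhs, %rhs, %out)
//       : (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
//
// writes only %out, so getNumResultOperands would return 1 here.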

// Returns the data users of the value and its aliases (e.g. memref.cast).
// Here, non-data users are DimOp, DeallocOp and ShapeOfOp.
SmallVector<Operation*, 4> getValueUsers(Value v);

// Represents a list of lmhlo ops that are going to be fused.
class FusionPattern {
 public:
  using FusionOpList = SmallVector<Operation*, 4>;
  using FusionValueList = SmallVector<Value, 4>;

  // Creates a new fusion pattern from a single op.
  FusionPattern(Operation* op);

  // Creates a new fusion pattern from the ops inside the lmhlo fusion op.
  FusionPattern(lmhlo::FusionOp op);

  // Returns the op list this fusion pattern represents.
  FusionOpList& getOpList() { return op_list_; }

  // Returns the dominant op of this fusion pattern.
  // For kLoop fusion, a dominant op may be any op that has external users.
  // For kInput fusion, a dominant op may be a row reduction (if one exists),
  // or a column reduction op.
  Operation* getDominantOp() { return dominant_op_; }

  // Sets the dominant op to the op provided.
  void setDominantOp(Operation* op) { dominant_op_ = op; }

  // Returns the fusion kind of the fusion pattern.
  FusionType getFusionType() { return fusion_type_; }

  // Sets the fusion type to the type provided.
  void setFusionType(FusionType type) { fusion_type_ = type; }

  // Returns true if this is a fusible fusion pattern.
  bool isFusible() { return getFusionType() != FusionType::kNone; }

  // Returns true if this fusion pattern is a kLoop fusion.
  bool isKLoopFusion() { return getFusionType() == FusionType::kLoop; }

  // Returns true if this fusion pattern is a kInput fusion.
  bool isKInputFusion() {
    return (getFusionType() == FusionType::kRowReduction ||
            getFusionType() == FusionType::kColReduction);
  }

  // Returns true if two fusion patterns can be merged into one bigger fusion
  // pattern.
  bool isMergeable(FusionPattern& other);

  // Merges two fusion patterns and returns the merged pattern. The original
  // pattern remains unmodified.
  FusionPattern merge(FusionPattern& other);

  // Merges two fusion patterns and returns the merged pattern. Replaces the
  // original pattern with the new merged pattern.
  FusionPattern& mergeInplace(FusionPattern& other);

  // Returns values that are consumed by the lmhlo ops inside the fusion
  // pattern.
  FusionValueList& getOperands() { return operands_; }

  // Returns values that are outputs of any lmhlo op in the fused pattern and
  // have consumers outside the fusion pattern.
  FusionValueList& getResults() { return results_; }

  // Returns the lmhlo ops inside the fused pattern whose outputs have
  // consumers outside the fusion pattern (the root ops).
  SmallVector<Operation*, 4>& getRootOps() { return root_ops_; }

  // Returns values that are outputs of any lmhlo op in the fused pattern and
  // are only consumed by the lmhlo ops inside the fused pattern.
  FusionValueList& getInternalResults() { return internal_results_; }

  // Returns the number of ops this fusion pattern contains.
  int size() { return op_list_.size(); }

  // Returns the effective number of ops (e.g. not counting const ops) this
  // fusion pattern contains.
  int effectiveSize();

  // Sorts the ops inside the fusion pattern according to the keys provided.
  void sortFusionOpListBy(DenseMap<Operation*, int>& op_to_idx);

 private:
  FusionPattern(SmallVectorImpl<Operation*>& op_list);

 private:
  // Calculates the inputs and outputs of the fusion pattern.
  void calculateOperandsAndResults();

 private:
  FusionOpList op_list_;
  Operation* dominant_op_ = nullptr;
  FusionType fusion_type_ = FusionType::kNone;
  FusionValueList operands_;
  FusionValueList results_;
  FusionValueList internal_results_;
  SmallVector<Operation*, 4> root_ops_;
};
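
// A minimal usage sketch (illustrative only; assumes `fusion` is an existing
// lmhlo::FusionOp in the module):
//
//   FusionPattern pattern(fusion);
//   if (pattern.isFusible()) {
//     Operation* dominant = pattern.getDominantOp();
//     for (Value result : pattern.getResults()) {
//       // ... generate code for each output that escapes the pattern ...
//     }
//   }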

// Represents a list of disjoint fusion patterns for a block.
using FusionPlan = std::vector<FusionPattern>;

using llvm::EquivalenceClasses;

// Supports using EquivalenceClasses for Value.
class ValueWrapper {
 public:
  explicit ValueWrapper(Value value) : value_(std::move(value)) {}

  Value getValue() const { return value_; }

  bool operator==(const ValueWrapper& rhs) const {
    return getValue() == rhs.getValue();
  }

 private:
  Value value_;
};

bool operator<(const ValueWrapper& lhs, const ValueWrapper& rhs);
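
// Illustrative note: llvm::EquivalenceClasses requires a strict weak ordering
// on its elements, which is what the operator< above provides. A typical use
// (hypothetical values `a` and `b`):
//
//   EquivalenceClasses<ValueWrapper> ec;
//   ec.unionSets(ValueWrapper(a), ValueWrapper(b));
//   bool same = ec.isEquivalent(ValueWrapper(a), ValueWrapper(b));  // true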

// This is a simple shape constraint analysis, which is used to
// guide fusion decisions (e.g. we only fuse shape-compatible ops).
//
// Currently, we only consider shape equality and same-number-of-elements
// equality propagation based on the shape constraint traits of elementwise ops
// (assuming that implicit shape broadcast is forbidden).
class ShapeConstraintAnalysis {
 public:
  explicit ShapeConstraintAnalysis(const SmallVectorImpl<Operation*>& op_list) {
    PropagateEquality(op_list);
  }

  // Returns true if `lhs` and `rhs` are supposed to have the same shape.
  bool HasSameShape(Value lhs, Value rhs) {
    return same_shape_impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs));
  }

  // Returns true if `lhs` and `rhs` are supposed to have the same number of
  // elements.
  bool HasSameNumElements(Value lhs, Value rhs) {
    return same_num_elements_impl_.isEquivalent(ValueWrapper(lhs),
                                                ValueWrapper(rhs));
  }

  // Returns the leader value of the same-shape equivalence class `val`
  // belongs to, or nullptr if `val` is not in any class.
  Value GetLeaderValueWithSameShape(Value val) const {
    if (same_shape_impl_.findLeader(ValueWrapper(val)) ==
        same_shape_impl_.member_end()) {
      return nullptr;
    }
    return same_shape_impl_.getLeaderValue(ValueWrapper(val)).getValue();
  }

 private:
  // Shape equality propagation based on the shape constraints of
  // elementwise ops.
  void PropagateEquality(const SmallVectorImpl<Operation*>& op_list);

  // Union-find sets backing the equivalence queries above.
  EquivalenceClasses<ValueWrapper> same_shape_impl_;
  EquivalenceClasses<ValueWrapper> same_num_elements_impl_;
};
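
// A minimal usage sketch (illustrative only; assumes `ops` holds the lmhlo
// ops of a block, e.g. a FusionPattern's op list):
//
//   ShapeConstraintAnalysis shape_analysis(ops);
//   if (shape_analysis.HasSameNumElements(lhs, rhs)) {
//     // lhs and rhs can live in the same parallel loop (kLoop fusion).
//   }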

}  // namespace lmhlo
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_FUSION_UTILS_H_