/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_FUSION_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_FUSION_UTILS_H_

#include <memory>
#include <vector>

#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/Support/Debug.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project

// This file implements helper functions and classes used for fusion and code
// generation.

namespace mlir {
namespace lmhlo {

// The kLoop fusion template satisfies:
//   - All ops in the fusion pattern are element-wise.
//   - All outputs of the fusion pattern have the same shape, or at least the
//     same number of elements, and thus fit into the same parallel loop.
//
// The kInput fusion template satisfies:
//   - Every op in the fusion pattern is either element-wise or a reduction.
//   - If an op is a reduction, its output cannot be consumed by other ops in
//     the same fusion pattern.
//   - All effective shapes of the outputs of the fusion pattern are the same.
//     - For an element-wise op, its effective shape is its output shape.
//     - For a reduction op, its effective shape is its operand shape.
//   - Currently the downstream codegen engine only supports 2D -> 1D tensor
//     reductions. TODO: lift this limitation.
//     - 2D row reduction: out[i] = sum({in[i][j] for all j})
//     - 2D column reduction: out[j] = sum({in[i][j] for all i})
enum FusionType {
  // Not a fusion pattern.
  kNone,
  // kLoop fusion pattern.
  kLoop,
  // kInput fusion pattern where all reduce ops of the fused pattern are row
  // reductions.
  kRowReduction,
  // kInput fusion pattern where all reduce ops of the fused pattern are
  // column reductions.
  kColReduction,
};

// Returns true if the op is an elementwise unary lmhlo op.
// TODO: use fusibility interface
bool isElementWiseUnary(Operation* op);

// Returns true if the op is an elementwise binary lmhlo op.
// TODO: use fusibility interface
bool isElementWiseBinary(Operation* op);

// Returns true if the op is an elementwise lmhlo op.
// TODO: use fusibility interface
bool isElementWise(Operation* op);

// Returns true if this op is a rank-2 row reduction.
bool isRank2RowReduction(Operation* op);

// Returns true if this op is a rank-2 column reduction.
bool isRank2ColReduction(Operation* op);

// Returns true if the op is supported by the downstream fusion codegen
// engine.
bool isFusible(Operation* op);

// Returns the number of operands that are supposed to be written.
// For some ops (e.g. lmhlo ops), some operands are output memrefs, and thus
// these operands are supposed to be updated in place.
int getNumResultOperands(Operation* op);

// Returns the data users of the value and its aliases (e.g. memref.cast).
// Here, non-data users are ops such as DimOp, DeallocOp and ShapeOfOp.
SmallVector<Operation*, 4> getValueUsers(Value v);
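// An illustrative sketch of the helpers above (the IR below is a
// hypothetical example, not taken from this header). lmhlo ops take their
// output buffers as trailing memref operands, so for
//
//   "lmhlo.add"(%lhs, %rhs, %out)
//       : (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
//
// isElementWiseBinary returns true, getNumResultOperands returns 1 (only
// %out is written), and getValueUsers(%out) returns the data users of %out
// while skipping non-data users such as DimOp and DeallocOp.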
// Represents a list of lmhlo ops that are going to be fused.
class FusionPattern {
 public:
  using FusionOpList = SmallVector<Operation*, 4>;
  using FusionValueList = SmallVector<Value, 4>;

  // Creates a new fusion pattern from a single op.
  FusionPattern(Operation* op);

  // Creates a new fusion pattern from the ops inside the lmhlo fusion op.
  FusionPattern(lmhlo::FusionOp op);

  // Returns the op list this fusion pattern represents.
  FusionOpList& getOpList() { return op_list_; }

  // Returns the dominant op of this fusion pattern.
  // For kLoop fusion, the dominant op may be any op that has external users.
  // For kInput fusion, the dominant op is a row reduction (if one exists),
  // otherwise a column reduction op.
  Operation* getDominantOp() { return dominant_op_; }

  // Sets the dominant op to the op provided.
  void setDominantOp(Operation* op) { dominant_op_ = op; }

  // Returns the fusion kind of the fusion pattern.
  FusionType getFusionType() { return fusion_type_; }

  // Sets the fusion type to the type provided.
  void setFusionType(FusionType type) { fusion_type_ = type; }

  // Returns true if this is a fusible fusion pattern.
  bool isFusible() { return getFusionType() != FusionType::kNone; }

  // Returns true if this fusion pattern is a kLoop fusion.
  bool isKLoopFusion() { return getFusionType() == FusionType::kLoop; }

  // Returns true if this fusion pattern is a kInput fusion.
  bool isKInputFusion() {
    return (getFusionType() == FusionType::kRowReduction ||
            getFusionType() == FusionType::kColReduction);
  }

  // Returns true if two fusion patterns can be merged into one bigger fusion
  // pattern.
  bool isMergeable(FusionPattern& other);

  // Merges two fusion patterns and returns the merged pattern. The original
  // pattern remains unmodified.
  FusionPattern merge(FusionPattern& other);

  // Merges two fusion patterns and returns the merged pattern. Replaces the
  // original pattern with the new merged pattern.
  FusionPattern& mergeInplace(FusionPattern& other);

  // Returns values that are consumed by the lmhlo ops inside the fusion
  // pattern.
  FusionValueList& getOperands() { return operands_; }

  // Returns values that are outputs of some lmhlo op in the fused pattern
  // and have consumers outside the fusion pattern.
  FusionValueList& getResults() { return results_; }

  // Returns the lmhlo ops in the fused pattern whose outputs have consumers
  // outside the fusion pattern.
  SmallVector<Operation*, 4>& getRootOps() { return root_ops_; }

  // Returns values that are outputs of some lmhlo op in the fused pattern
  // and are only consumed by the lmhlo ops inside the fused pattern.
  FusionValueList& getInternalResults() { return internal_results_; }

  // Returns the number of ops this fusion pattern contains.
  int size() { return op_list_.size(); }

  // Returns the effective number of ops (e.g. not counting const ops) this
  // fusion pattern contains.
  int effectiveSize();

  // Sorts the ops inside the fusion pattern according to the keys provided.
  void sortFusionOpListBy(DenseMap<Operation*, int>& op_to_idx);

 private:
  FusionPattern(SmallVectorImpl<Operation*>& op_list);

  // Calculates the inputs and outputs of the fusion pattern.
  void calculateOperandsAndResults();

 private:
  FusionOpList op_list_;
  Operation* dominant_op_ = nullptr;
  FusionType fusion_type_ = FusionType::kNone;
  FusionValueList operands_;
  FusionValueList results_;
  FusionValueList internal_results_;
  SmallVector<Operation*, 4> root_ops_;
};

// Represents a list of disjoint fusion patterns for a block.
using FusionPlan = std::vector<FusionPattern>;
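// A minimal usage sketch (assumptions: `fusion_op` is a valid
// lmhlo::FusionOp taken from the enclosing module; error handling and pass
// boilerplate are omitted):
//
//   FusionPattern pattern(fusion_op);
//   if (pattern.isFusible()) {
//     Operation* dominant = pattern.getDominantOp();
//     bool is_input_fusion = pattern.isKInputFusion();
//     for (Value operand : pattern.getOperands()) { /* inspect inputs */ }
//     for (Value result : pattern.getResults()) { /* inspect outputs */ }
//   }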
using llvm::EquivalenceClasses;

// Supports using EquivalenceClasses with Value.
class ValueWrapper {
 public:
  explicit ValueWrapper(Value value) : value_(std::move(value)) {}

  Value getValue() const { return value_; }

  bool operator==(const ValueWrapper& rhs) const {
    return getValue() == rhs.getValue();
  }

 private:
  Value value_;
};

bool operator<(const ValueWrapper& lhs, const ValueWrapper& rhs);

// A simple shape constraint analysis used to guide fusion decisions (e.g. we
// only fuse shape-compatible ops).
//
// Currently, we only propagate shape equality and same-number-of-elements
// equality based on the shape constraint traits of elementwise ops (assuming
// that implicit shape broadcast is forbidden).
class ShapeConstraintAnalysis {
 public:
  explicit ShapeConstraintAnalysis(
      const SmallVectorImpl<Operation*>& op_list) {
    PropagateEquality(op_list);
  }

  // Returns true if `lhs` and `rhs` are supposed to have the same shape.
  bool HasSameShape(Value lhs, Value rhs) {
    return same_shape_impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs));
  }

  // Returns true if `lhs` and `rhs` are supposed to have the same number of
  // elements.
  bool HasSameNumElements(Value lhs, Value rhs) {
    return same_num_elements_impl_.isEquivalent(ValueWrapper(lhs),
                                                ValueWrapper(rhs));
  }

  // Returns the leader value of the same-shape equivalence class that `val`
  // belongs to, or nullptr if `val` is not in any class.
  Value GetLeaderValueWithSameShape(Value val) const {
    if (same_shape_impl_.findLeader(ValueWrapper(val)) ==
        same_shape_impl_.member_end()) {
      return nullptr;
    }
    return same_shape_impl_.getLeaderValue(ValueWrapper(val)).getValue();
  }

 private:
  // Shape equality propagation based on the shape constraints of elementwise
  // ops.
  void PropagateEquality(const SmallVectorImpl<Operation*>& op_list);

  // Union-find sets for shape equality and element-count equality.
  EquivalenceClasses<ValueWrapper> same_shape_impl_;
  EquivalenceClasses<ValueWrapper> same_num_elements_impl_;
};

}  // namespace lmhlo
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_FUSION_UTILS_H_