/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #include "mlir-hlo/Dialect/mhlo/transforms/fusion_utils.h"
17
18 #include <algorithm>
19
20 #include "mlir/Dialect/Shape/IR/Shape.h" // TF:llvm-project
21 #include "mlir/IR/MLIRContext.h" // TF:llvm-project
22 #include "mlir/IR/Matchers.h"
23
// This file implements some helper functions and classes used to do fusion
// & code generation.

27 namespace mlir {
28 namespace lmhlo {
29
30 // Returns true if the op is an elementwise unary lmhlo op.
31 // TODO(disc): use fusibility interface
32 // TODO(disc): Unify with disc_supported_list.h and Elementwise Trait
isElementWiseUnary(Operation * op)33 bool isElementWiseUnary(Operation* op) {
34 // clang-format off
35 return isa<
36 lmhlo::AbsOp,
37 lmhlo::CeilOp,
38 lmhlo::ConvertOp,
39 lmhlo::CopyOp,
40 lmhlo::CosOp,
41 lmhlo::ExpOp,
42 lmhlo::FloorOp,
43 lmhlo::IsFiniteOp,
44 lmhlo::LogOp,
45 lmhlo::NegOp,
46 lmhlo::NotOp,
47 lmhlo::RsqrtOp,
48 lmhlo::SignOp,
49 lmhlo::SqrtOp,
50 lmhlo::TanhOp
51 >(op);
52 // clang-format on
53 }
54
55 // Returns true if the op is an elementwise binary lmhlo op.
56 // TODO(disc): use fusibility interface
isElementWiseBinary(Operation * op)57 bool isElementWiseBinary(Operation* op) {
58 // clang-format off
59 return isa<
60 lmhlo::AddOp,
61 lmhlo::AndOp,
62 lmhlo::CompareOp,
63 lmhlo::DivOp,
64 lmhlo::MaxOp,
65 lmhlo::MinOp,
66 lmhlo::MulOp,
67 lmhlo::OrOp,
68 lmhlo::PowOp,
69 lmhlo::SubOp
70 >(op);
71 // clang-format on
72 }
73
74 // Returns true if the op is an elementwise lmhlo op.
75 // TODO(disc): use fusibility interface
isElementWise(Operation * op)76 bool isElementWise(Operation* op) {
77 return isElementWiseUnary(op) || isElementWiseBinary(op);
78 }
79
80 // Returns true if this op is a rank-2 row reduction.
isRank2RowReduction(Operation * op)81 bool isRank2RowReduction(Operation* op) {
82 auto reduce_op = dyn_cast<lmhlo::ReduceOp>(op);
83 if (!reduce_op || reduce_op.dimensions().getNumElements() != 1) return false;
84
85 int rank = op->getOperand(0).getType().cast<MemRefType>().getRank();
86 auto dimensions = reduce_op.dimensions().getValues<int64_t>();
87 return ((*dimensions.begin() == 1) && (rank == 2));
88 }
89
90 // Returns true if this op is a rank-2 column reduction.
isRank2ColReduction(Operation * op)91 bool isRank2ColReduction(Operation* op) {
92 auto reduce_op = dyn_cast<lmhlo::ReduceOp>(op);
93 if (!reduce_op || reduce_op.dimensions().getNumElements() != 1) return false;
94
95 int rank = op->getOperand(0).getType().cast<MemRefType>().getRank();
96 auto dimensions = reduce_op.dimensions().getValues<int64_t>();
97 return ((*dimensions.begin() == 0) && (rank == 2));
98 }
99
100 // Returns true if the op is supported by the downstreaming fusion codegen
101 // engine.
isFusible(Operation * op)102 bool isFusible(Operation* op) {
103 // Only scalar const are supported by the fusion codegen engine a.t.m.
104 if (dyn_cast<lmhlo::ConstOp>(op)) {
105 MemRefType type = op->getOperand(0).getType().cast<MemRefType>();
106 return (type.getRank() == 0);
107 }
108
109 // All element ops are supported by the fusion codegen engine.
110 if (isElementWise(op)) return true;
111
112 // Only rank-2 tensor -> rank-1 tensor reduction are supported now.
113 if (isRank2RowReduction(op) || isRank2ColReduction(op)) return true;
114
115 // clang-format off
116 return isa<
117 lmhlo::BroadcastInDimOp,
118 lmhlo::BroadcastOp,
119 lmhlo::ConcatenateOp,
120 lmhlo::DynamicBroadcastInDimOp,
121 lmhlo::DynamicGatherOp,
122 lmhlo::DynamicIotaOp,
123 lmhlo::DynamicPadOp,
124 lmhlo::DynamicReshapeOp,
125 lmhlo::GatherOp,
126 lmhlo::RealDynamicSliceOp,
127 lmhlo::ReshapeOp,
128 lmhlo::SelectOp,
129 lmhlo::SliceOp,
130 lmhlo::TransposeOp
131 >(op);
132 // clang-format on
133 }
134
135 // Returns the number of operands that are supposed to be written.
136 // For some ops (e.g. lmhlo ops), some operands are the output memrefs
137 // Thus these operands are supposed to be updated.
getNumResultOperands(Operation * op)138 int getNumResultOperands(Operation* op) {
139 if (!isa<LmhloOp>(op)) {
140 return 0;
141 }
142
143 auto isWritable = [&](Value operand) -> bool {
144 llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 2> effects;
145 MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
146 // Suppose that operands of op without `MemoryEffectOpInterface` are
147 // readonly.
148 if (!interface) return false;
149
150 interface.getEffectsOnValue(operand, effects);
151 return llvm::any_of(
152 effects, [](const mlir::MemoryEffects::EffectInstance& instance) {
153 return mlir::isa<mlir::MemoryEffects::Write>(instance.getEffect());
154 });
155 };
156
157 return llvm::count_if(op->getOperands(),
158 [&](Value v) { return isWritable(v); });
159 }
160
161 // Returns data users of the value and its aliases (e.g. memref.cast).
162 // Here non-data users means DimOp, DeallocOp and ShapeOfOp.
getValueUsers(Value v)163 SmallVector<Operation*, 4> getValueUsers(Value v) {
164 SmallVector<Operation*, 4> users;
165 SmallVector<Value, 4> worklist;
166 worklist.push_back(v);
167 while (!worklist.empty()) {
168 Value curr = worklist.back();
169 worklist.pop_back();
170 for (Operation* user : curr.getUsers()) {
171 // Skip non-data users
172 if (isa<memref::DimOp, memref::DeallocOp, shape::ShapeOfOp>(user)) {
173 continue;
174 }
175 // alias value
176 if (isa<memref::CastOp>(user)) {
177 worklist.push_back(user->getResult(0));
178 } else {
179 users.push_back(user);
180 }
181 }
182 }
183 return users;
184 }
185
186 // Create a new fusion pattern from a single op.
FusionPattern(Operation * op)187 FusionPattern::FusionPattern(Operation* op) {
188 op_list_.push_back(op);
189 if (isRank2RowReduction(op)) {
190 fusion_type_ = FusionType::kRowReduction;
191 } else if (isRank2ColReduction(op)) {
192 fusion_type_ = FusionType::kColReduction;
193 } else if (mlir::lmhlo::isFusible(op)) {
194 fusion_type_ = FusionType::kLoop;
195 } else {
196 fusion_type_ = FusionType::kNone;
197 }
198 dominant_op_ = op;
199 calculateOperandsAndResults();
200 }
201
202 // Create a new fusion pattern from the ops inside the lmhlo fusion op.
FusionPattern(lmhlo::FusionOp op)203 FusionPattern::FusionPattern(lmhlo::FusionOp op) {
204 for (Operation& op : op.region().getBlocks().front()) {
205 op_list_.push_back(&op);
206 }
207
208 // Figure out fusion type and dominant op for the fusion pattern.
209 for (Operation* op : op_list_) {
210 if (isRank2RowReduction(op)) {
211 fusion_type_ = FusionType::kRowReduction;
212 dominant_op_ = op;
213 } else if (isRank2ColReduction(op)) {
214 if (fusion_type_ != FusionType::kRowReduction) {
215 fusion_type_ = FusionType::kColReduction;
216 dominant_op_ = op;
217 }
218 } else if (lmhlo::isFusible(op)) {
219 // Ignore if already a kRowReduction or kColReduction, otherwise update
220 // the fusion type to kLoop and dominant op to current op. This supposes
221 // that the last op inside the block is a valid candidate dominant op if
222 // the fusion pattern is a kLoop.
223 if (fusion_type_ == FusionType::kNone ||
224 fusion_type_ == FusionType::kLoop) {
225 fusion_type_ = FusionType::kLoop;
226 dominant_op_ = op;
227 }
228 } else if (!isa<lmhlo::TerminatorOp>(op)) {
229 // Not a supported fusionOp, early stop.
230 fusion_type_ = FusionType::kNone;
231 dominant_op_ = nullptr;
232 break;
233 }
234 }
235
236 if (isFusible()) calculateOperandsAndResults();
237 }
238
239 // Create a new fusion pattern from a valid fusion op list.
FusionPattern(SmallVectorImpl<Operation * > & op_list)240 FusionPattern::FusionPattern(SmallVectorImpl<Operation*>& op_list)
241 : op_list_(op_list.begin(), op_list.end()) {
242 calculateOperandsAndResults();
243 }
244
245 // Returns true if two fusion patterns can be merged into one bigger fusion
246 // pattern.
isMergeable(FusionPattern & other)247 bool FusionPattern::isMergeable(FusionPattern& other) {
248 if (!this->isFusible() || !other.isFusible()) return false;
249 return true;
250 }
251
252 // Merges two fusion patterns and returns the merged pattern. The original
253 // pattern remains unmodified.
merge(FusionPattern & other)254 FusionPattern FusionPattern::merge(FusionPattern& other) {
255 assert(isMergeable(other));
256 FusionOpList new_op_list = op_list_;
257 new_op_list.insert(new_op_list.end(), other.getOpList().begin(),
258 other.getOpList().end());
259 FusionPattern new_fusion_pattern{new_op_list};
260
261 FusionType newType = FusionType::kLoop;
262 Operation* newDominant = getDominantOp();
263
264 // kRowReduction + (kRowReduction | kColReduction | kLoop) = kRowReduction
265 // kColReduction + (kColReduction | kLoop) = kColReduction
266 // kLoop + kLoop = kLoop
267 if (getFusionType() == FusionType::kRowReduction ||
268 other.getFusionType() == FusionType::kRowReduction) {
269 newType = FusionType::kRowReduction;
270 if (getFusionType() != FusionType::kRowReduction)
271 newDominant = other.getDominantOp();
272 } else if (getFusionType() == FusionType::kColReduction ||
273 other.getFusionType() == FusionType::kColReduction) {
274 newType = FusionType::kColReduction;
275 if (getFusionType() != FusionType::kColReduction)
276 newDominant = other.getDominantOp();
277 }
278
279 new_fusion_pattern.setDominantOp(newDominant);
280 new_fusion_pattern.setFusionType(newType);
281 return new_fusion_pattern;
282 }
283
284 // Merges two fusion patterns and returns the merged pattern. Replaces the
285 // original pattern with new merged pattern.
mergeInplace(FusionPattern & other)286 FusionPattern& FusionPattern::mergeInplace(FusionPattern& other) {
287 *this = merge(other);
288 return *this;
289 }
290
291 // Returns the effective size (e.g. not counting const ops) of the ops this
292 // fusion pattern contains.
effectiveSize()293 int FusionPattern::effectiveSize() {
294 return llvm::count_if(
295 op_list_, [](Operation* op) { return !matchPattern(op, m_Constant()); });
296 }
297
298 // Sorts the ops inside the fusion pattern according to the keys provided.
sortFusionOpListBy(DenseMap<Operation *,int> & op_to_idx)299 void FusionPattern::sortFusionOpListBy(DenseMap<Operation*, int>& op_to_idx) {
300 std::sort(op_list_.begin(), op_list_.end(),
301 [&](Operation* lhs, Operation* rhs) {
302 return op_to_idx[lhs] < op_to_idx[rhs];
303 });
304 }
305
306 // Calculates the inputs and outputs of the fusion pattern.
calculateOperandsAndResults()307 void FusionPattern::calculateOperandsAndResults() {
308 DenseSet<Value> input_set;
309 DenseSet<Value> result_set;
310 DenseSet<Value> internal_result_set;
311 DenseSet<Operation*> op_set(op_list_.begin(), op_list_.end());
312
313 DenseMap<Value, Operation*> last_writer;
314 for (Operation* op : op_list_) {
315 int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
316 for (Value v : op->getOperands().drop_front(num_input_operand)) {
317 bool inserted = last_writer.try_emplace(v, op).second;
318 (void)inserted;
319 assert(inserted);
320
321 bool has_external_user = false;
322 for (Operation* user : getValueUsers(v)) {
323 if (!op_set.contains(user)) {
324 has_external_user = true;
325 break;
326 }
327 }
328
329 if (has_external_user) {
330 results_.push_back(v);
331 root_ops_.push_back(op);
332 } else {
333 internal_results_.push_back(v);
334 }
335 }
336 }
337
338 for (Operation* op : op_list_) {
339 int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
340 for (Value value : op->getOperands().take_front(num_input_operand)) {
341 if (last_writer.find(value) != last_writer.end()) {
342 // skip if defining op is in the pattern
343 continue;
344 }
345 input_set.insert(value);
346 }
347 }
348
349 for (Value v : input_set) operands_.push_back(v);
350 }
351
352 // Supports using EquivalenceClasses for Value
operator <(const ValueWrapper & lhs,const ValueWrapper & rhs)353 bool operator<(const ValueWrapper& lhs, const ValueWrapper& rhs) {
354 auto lhs_value = lhs.getValue().getAsOpaquePointer();
355 auto rhs_value = rhs.getValue().getAsOpaquePointer();
356 return lhs_value < rhs_value;
357 }
358
359 // shape equality propagation based on the shape constrains of
360 // elementwise ops.
PropagateEquality(const SmallVectorImpl<Operation * > & op_list)361 void ShapeConstraintAnalysis::PropagateEquality(
362 const SmallVectorImpl<Operation*>& op_list) {
363 bool converged = true;
364 do {
365 converged = true;
366 auto update = [&](Value lhs, Value rhs,
367 EquivalenceClasses<ValueWrapper>& impl) {
368 if (!impl.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs))) {
369 converged = false;
370 impl.unionSets(ValueWrapper(lhs), ValueWrapper(rhs));
371 }
372 };
373 for (Operation* op : op_list) {
374 int num_operand = op->getNumOperands();
375 // Propagates same num_elements equality, and shape equality
376 if (isElementWise(op)) {
377 Value lhs = op->getOperand(0);
378 for (Value rhs : op->getOperands().drop_front()) {
379 update(lhs, rhs, same_num_elements_impl_);
380 update(lhs, rhs, same_shape_impl_);
381 }
382 }
383 // Propagates same num_elements equality, not shape equality
384 if (isa<lmhlo::DynamicReshapeOp, lmhlo::ReshapeOp, lmhlo::TransposeOp>(
385 op)) {
386 Value input = op->getOperand(0);
387 // The last operand is the output memref by design
388 Value output = op->getOperand(num_operand - 1);
389 update(input, output, same_num_elements_impl_);
390 }
391 }
392 } while (!converged);
393 }
394
395 } // namespace lmhlo
396 } // namespace mlir
397