/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir-hlo/Dialect/mhlo/transforms/fusion_utils.h"

#include <algorithm>

#include "mlir/Dialect/Shape/IR/Shape.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"          // TF:llvm-project
#include "mlir/IR/Matchers.h"

// This file implements helper functions and classes used for fusion and code
// generation.

namespace mlir {
namespace lmhlo {

// Returns true if the op is an elementwise unary lmhlo op.
// TODO(disc): use fusibility interface
// TODO(disc): Unify with disc_supported_list.h and Elementwise Trait
bool isElementWiseUnary(Operation* op) {
  // clang-format off
  return isa<
    lmhlo::AbsOp,
    lmhlo::CeilOp,
    lmhlo::ConvertOp,
    lmhlo::CopyOp,
    lmhlo::CosOp,
    lmhlo::ExpOp,
    lmhlo::FloorOp,
    lmhlo::IsFiniteOp,
    lmhlo::LogOp,
    lmhlo::NegOp,
    lmhlo::NotOp,
    lmhlo::RsqrtOp,
    lmhlo::SignOp,
    lmhlo::SqrtOp,
    lmhlo::TanhOp
  >(op);
  // clang-format on
}

// Returns true if the op is an elementwise binary lmhlo op.
// TODO(disc): use fusibility interface
bool isElementWiseBinary(Operation* op) {
  // clang-format off
  return isa<
    lmhlo::AddOp,
    lmhlo::AndOp,
    lmhlo::CompareOp,
    lmhlo::DivOp,
    lmhlo::MaxOp,
    lmhlo::MinOp,
    lmhlo::MulOp,
    lmhlo::OrOp,
    lmhlo::PowOp,
    lmhlo::SubOp
  >(op);
  // clang-format on
}

// Returns true if the op is an elementwise lmhlo op.
// TODO(disc): use fusibility interface
bool isElementWise(Operation* op) {
  return isElementWiseUnary(op) || isElementWiseBinary(op);
}

// Returns true if this op is a rank-2 row reduction.
bool isRank2RowReduction(Operation* op) {
  auto reduce_op = dyn_cast<lmhlo::ReduceOp>(op);
  if (!reduce_op || reduce_op.dimensions().getNumElements() != 1) return false;

  int rank = op->getOperand(0).getType().cast<MemRefType>().getRank();
  auto dimensions = reduce_op.dimensions().getValues<int64_t>();
  return ((*dimensions.begin() == 1) && (rank == 2));
}

// Returns true if this op is a rank-2 column reduction.
bool isRank2ColReduction(Operation* op) {
  auto reduce_op = dyn_cast<lmhlo::ReduceOp>(op);
  if (!reduce_op || reduce_op.dimensions().getNumElements() != 1) return false;

  int rank = op->getOperand(0).getType().cast<MemRefType>().getRank();
  auto dimensions = reduce_op.dimensions().getValues<int64_t>();
  return ((*dimensions.begin() == 0) && (rank == 2));
}
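
// Illustrative sketch (not from the original source; the textual form of
// lmhlo.reduce is assumed): the two predicates above differ only in which
// dimension of a rank-2 operand is reduced. For example, an op roughly of the
// form
//
//   "lmhlo.reduce"(%in, %init, %out) ({ ... })
//       {dimensions = dense<1> : tensor<1xi64>}
//       : (memref<?x?xf32>, memref<f32>, memref<?xf32>) -> ()
//
// is classified as a row reduction (the innermost dimension is reduced),
// whereas `dimensions = dense<0>` on the same shapes makes it a column
// reduction.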

// Returns true if the op is supported by the downstream fusion codegen
// engine.
bool isFusible(Operation* op) {
  // Only scalar constants are supported by the fusion codegen engine at the
  // moment.
  if (dyn_cast<lmhlo::ConstOp>(op)) {
    MemRefType type = op->getOperand(0).getType().cast<MemRefType>();
    return (type.getRank() == 0);
  }

  // All elementwise ops are supported by the fusion codegen engine.
  if (isElementWise(op)) return true;

  // Only rank-2 tensor -> rank-1 tensor reductions are supported for now.
  if (isRank2RowReduction(op) || isRank2ColReduction(op)) return true;

  // clang-format off
  return isa<
    lmhlo::BroadcastInDimOp,
    lmhlo::BroadcastOp,
    lmhlo::ConcatenateOp,
    lmhlo::DynamicBroadcastInDimOp,
    lmhlo::DynamicGatherOp,
    lmhlo::DynamicIotaOp,
    lmhlo::DynamicPadOp,
    lmhlo::DynamicReshapeOp,
    lmhlo::GatherOp,
    lmhlo::RealDynamicSliceOp,
    lmhlo::ReshapeOp,
    lmhlo::SelectOp,
    lmhlo::SliceOp,
    lmhlo::TransposeOp
  >(op);
  // clang-format on
}

// Returns the number of operands that are supposed to be written.
// For some ops (e.g. lmhlo ops), some operands are the output memrefs,
// and thus these operands are supposed to be updated.
int getNumResultOperands(Operation* op) {
  if (!isa<LmhloOp>(op)) {
    return 0;
  }

  auto isWritable = [&](Value operand) -> bool {
    llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 2> effects;
    MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
    // Assume that operands of an op without `MemoryEffectOpInterface` are
    // read-only.
    if (!interface) return false;

    interface.getEffectsOnValue(operand, effects);
    return llvm::any_of(
        effects, [](const mlir::MemoryEffects::EffectInstance& instance) {
          return mlir::isa<mlir::MemoryEffects::Write>(instance.getEffect());
        });
  };

  return llvm::count_if(op->getOperands(),
                        [&](Value v) { return isWritable(v); });
}
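
// Illustrative sketch (not from the original source): for a buffer-based
// binary op such as
//
//   "lmhlo.add"(%lhs, %rhs, %out)
//       : (memref<?xf32>, memref<?xf32>, memref<?xf32>) -> ()
//
// only %out carries a MemoryEffects::Write effect, so getNumResultOperands
// returns 1 and the first two operands are treated as read-only inputs.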

// Returns the data users of the value and its aliases (e.g. memref.cast).
// Here non-data users are DimOp, DeallocOp and ShapeOfOp.
SmallVector<Operation*, 4> getValueUsers(Value v) {
  SmallVector<Operation*, 4> users;
  SmallVector<Value, 4> worklist;
  worklist.push_back(v);
  while (!worklist.empty()) {
    Value curr = worklist.back();
    worklist.pop_back();
    for (Operation* user : curr.getUsers()) {
      // Skip non-data users.
      if (isa<memref::DimOp, memref::DeallocOp, shape::ShapeOfOp>(user)) {
        continue;
      }
      // Follow alias values instead of reporting them as users.
      if (isa<memref::CastOp>(user)) {
        worklist.push_back(user->getResult(0));
      } else {
        users.push_back(user);
      }
    }
  }
  return users;
}
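
// Illustrative sketch (not from the original source; op spellings assumed):
// given
//
//   %cast = memref.cast %buf : memref<?xf32> to memref<16xf32>
//   %d = memref.dim %buf, %c0 : memref<?xf32>
//   "lmhlo.abs"(%cast, %out) : (memref<16xf32>, memref<16xf32>) -> ()
//
// getValueUsers(%buf) returns only the lmhlo.abs op: memref.dim is skipped as
// a non-data user, and memref.cast is followed as an alias rather than being
// reported itself.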

// Create a new fusion pattern from a single op.
FusionPattern::FusionPattern(Operation* op) {
  op_list_.push_back(op);
  if (isRank2RowReduction(op)) {
    fusion_type_ = FusionType::kRowReduction;
  } else if (isRank2ColReduction(op)) {
    fusion_type_ = FusionType::kColReduction;
  } else if (mlir::lmhlo::isFusible(op)) {
    fusion_type_ = FusionType::kLoop;
  } else {
    fusion_type_ = FusionType::kNone;
  }
  dominant_op_ = op;
  calculateOperandsAndResults();
}

// Create a new fusion pattern from the ops inside the lmhlo fusion op.
FusionPattern::FusionPattern(lmhlo::FusionOp op) {
  for (Operation& op : op.region().getBlocks().front()) {
    op_list_.push_back(&op);
  }

  // Figure out the fusion type and dominant op for the fusion pattern.
  for (Operation* op : op_list_) {
    if (isRank2RowReduction(op)) {
      fusion_type_ = FusionType::kRowReduction;
      dominant_op_ = op;
    } else if (isRank2ColReduction(op)) {
      if (fusion_type_ != FusionType::kRowReduction) {
        fusion_type_ = FusionType::kColReduction;
        dominant_op_ = op;
      }
    } else if (lmhlo::isFusible(op)) {
      // Ignore if already a kRowReduction or kColReduction, otherwise update
      // the fusion type to kLoop and the dominant op to the current op. This
      // assumes that the last op inside the block is a valid candidate
      // dominant op if the fusion pattern is a kLoop.
      if (fusion_type_ == FusionType::kNone ||
          fusion_type_ == FusionType::kLoop) {
        fusion_type_ = FusionType::kLoop;
        dominant_op_ = op;
      }
    } else if (!isa<lmhlo::TerminatorOp>(op)) {
      // Not a supported fusion op; stop early.
      fusion_type_ = FusionType::kNone;
      dominant_op_ = nullptr;
      break;
    }
  }

  if (isFusible()) calculateOperandsAndResults();
}

// Create a new fusion pattern from a valid fusion op list.
FusionPattern::FusionPattern(SmallVectorImpl<Operation*>& op_list)
    : op_list_(op_list.begin(), op_list.end()) {
  calculateOperandsAndResults();
}

// Returns true if two fusion patterns can be merged into one bigger fusion
// pattern.
bool FusionPattern::isMergeable(FusionPattern& other) {
  if (!this->isFusible() || !other.isFusible()) return false;
  return true;
}

// Merges two fusion patterns and returns the merged pattern. The original
// pattern remains unmodified.
FusionPattern FusionPattern::merge(FusionPattern& other) {
  assert(isMergeable(other));
  FusionOpList new_op_list = op_list_;
  new_op_list.insert(new_op_list.end(), other.getOpList().begin(),
                     other.getOpList().end());
  FusionPattern new_fusion_pattern{new_op_list};

  FusionType newType = FusionType::kLoop;
  Operation* newDominant = getDominantOp();

  // kRowReduction + (kRowReduction | kColReduction | kLoop) = kRowReduction
  // kColReduction + (kColReduction | kLoop) = kColReduction
  // kLoop + kLoop = kLoop
  if (getFusionType() == FusionType::kRowReduction ||
      other.getFusionType() == FusionType::kRowReduction) {
    newType = FusionType::kRowReduction;
    if (getFusionType() != FusionType::kRowReduction)
      newDominant = other.getDominantOp();
  } else if (getFusionType() == FusionType::kColReduction ||
             other.getFusionType() == FusionType::kColReduction) {
    newType = FusionType::kColReduction;
    if (getFusionType() != FusionType::kColReduction)
      newDominant = other.getDominantOp();
  }

  new_fusion_pattern.setDominantOp(newDominant);
  new_fusion_pattern.setFusionType(newType);
  return new_fusion_pattern;
}
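
// Usage sketch (illustrative; the ops and variable names are assumed):
//
//   FusionPattern rowRed(reduceOp);  // a rank-2 row reduction -> kRowReduction
//   FusionPattern loop(addOp);       // an elementwise op -> kLoop
//   if (rowRed.isMergeable(loop)) {
//     FusionPattern merged = rowRed.merge(loop);
//     // merged.getFusionType() == FusionType::kRowReduction
//     // merged.getDominantOp() == reduceOp
//   }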

// Merges two fusion patterns and returns the merged pattern. Replaces the
// original pattern with the new merged pattern.
FusionPattern& FusionPattern::mergeInplace(FusionPattern& other) {
  *this = merge(other);
  return *this;
}

// Returns the effective size (i.e. not counting const ops) of the ops this
// fusion pattern contains.
int FusionPattern::effectiveSize() {
  return llvm::count_if(
      op_list_, [](Operation* op) { return !matchPattern(op, m_Constant()); });
}

// Sorts the ops inside the fusion pattern according to the keys provided.
void FusionPattern::sortFusionOpListBy(DenseMap<Operation*, int>& op_to_idx) {
  std::sort(op_list_.begin(), op_list_.end(),
            [&](Operation* lhs, Operation* rhs) {
              return op_to_idx[lhs] < op_to_idx[rhs];
            });
}

// Calculates the inputs and outputs of the fusion pattern.
void FusionPattern::calculateOperandsAndResults() {
  DenseSet<Value> input_set;
  DenseSet<Value> result_set;
  DenseSet<Value> internal_result_set;
  DenseSet<Operation*> op_set(op_list_.begin(), op_list_.end());

  DenseMap<Value, Operation*> last_writer;
  for (Operation* op : op_list_) {
    int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
    for (Value v : op->getOperands().drop_front(num_input_operand)) {
      bool inserted = last_writer.try_emplace(v, op).second;
      (void)inserted;
      assert(inserted);

      bool has_external_user = false;
      for (Operation* user : getValueUsers(v)) {
        if (!op_set.contains(user)) {
          has_external_user = true;
          break;
        }
      }

      if (has_external_user) {
        results_.push_back(v);
        root_ops_.push_back(op);
      } else {
        internal_results_.push_back(v);
      }
    }
  }

  for (Operation* op : op_list_) {
    int num_input_operand = op->getNumOperands() - getNumResultOperands(op);
    for (Value value : op->getOperands().take_front(num_input_operand)) {
      if (last_writer.find(value) != last_writer.end()) {
        // Skip if the defining op is in the pattern.
        continue;
      }
      input_set.insert(value);
    }
  }

  for (Value v : input_set) operands_.push_back(v);
}
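
// Illustrative sketch (not from the original source): for a pattern
// {add(%a, %b, %t), abs(%t, %out)} where %out is also read outside the
// pattern, this computes operands_ = {%a, %b}, results_ = {%out} (with the
// abs op recorded as its root op), and internal_results_ = {%t}, since %t is
// only consumed inside the pattern.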

// Supports using EquivalenceClasses for Value.
bool operator<(const ValueWrapper& lhs, const ValueWrapper& rhs) {
  auto lhs_value = lhs.getValue().getAsOpaquePointer();
  auto rhs_value = rhs.getValue().getAsOpaquePointer();
  return lhs_value < rhs_value;
}

// Shape equality propagation based on the shape constraints of
// elementwise ops.
void ShapeConstraintAnalysis::PropagateEquality(
    const SmallVectorImpl<Operation*>& op_list) {
  bool converged = true;
  do {
    converged = true;
    auto update = [&](Value lhs, Value rhs,
                      EquivalenceClasses<ValueWrapper>& impl) {
      if (!impl.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs))) {
        converged = false;
        impl.unionSets(ValueWrapper(lhs), ValueWrapper(rhs));
      }
    };
    for (Operation* op : op_list) {
      int num_operand = op->getNumOperands();
      // Propagates both same-num-elements equality and same-shape equality.
      if (isElementWise(op)) {
        Value lhs = op->getOperand(0);
        for (Value rhs : op->getOperands().drop_front()) {
          update(lhs, rhs, same_num_elements_impl_);
          update(lhs, rhs, same_shape_impl_);
        }
      }
      // Propagates same-num-elements equality, but not same-shape equality.
      if (isa<lmhlo::DynamicReshapeOp, lmhlo::ReshapeOp, lmhlo::TransposeOp>(
              op)) {
        Value input = op->getOperand(0);
        // The last operand is the output memref by design.
        Value output = op->getOperand(num_operand - 1);
        update(input, output, same_num_elements_impl_);
      }
    }
  } while (!converged);
}
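
// Illustrative sketch (not from the original source; op spellings and types
// are assumed): for the pair
//
//   "lmhlo.add"(%x, %y, %z)
//       : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
//   "lmhlo.dynamic_reshape"(%z, %shape, %w)
//       : (memref<?x?xf32>, memref<2xindex>, memref<?x?xf32>) -> ()
//
// the analysis places %x, %y and %z in the same same-shape and
// same-num-elements classes, and additionally places %w in the same
// same-num-elements class as %z, but not in its same-shape class.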

}  // namespace lmhlo
}  // namespace mlir