/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements the logic for lowering the roots of LHLO ops to
// parallel loops (SCF dialect).
#include "llvm/Support/Debug.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/fusion_utils.h"
#include "mlir-hlo/Dialect/mhlo/transforms/lhlo_elemental_utils.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir-hlo/utils/codegen_utils.h"
#include "mlir-hlo/utils/placement_utils.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"

using mlir::codegen_utils::calcMultiDimIndex;

namespace mlir {
namespace lmhlo {

#define GEN_PASS_CLASSES
#include "mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.h.inc"

namespace {

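// Lowers one elementwise LHLO op: recovers the multidim index from the
// output's linear index, loads (or reuses cached) operand elements, maps the
// op to its scalar counterpart, and stores the result.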
template <typename LHLO_OpTy>
LogicalResult elemwiseLowerHelper(
    OpBuilder& b, Location loc, Operation* op, Value output_linear_index,
    const ShapeConstraintAnalysis* shape_constraint_analysis) {
  if (!isa<LHLO_OpTy>(op) || !op->hasTrait<mlir::OpTrait::Elementwise>())
    return failure();

  Value result_memref = cast<LmhloOp>(op).getResultBuffer();
  Value memref = result_memref;
  if (shape_constraint_analysis) {
    Value leader_memref =
        shape_constraint_analysis->GetLeaderValueWithSameShape(result_memref);
    if (leader_memref != nullptr) memref = leader_memref;
  }
  // TODO(disc): Replace with memref.Delinearize
  auto multidim_index = calcMultiDimIndex(b, loc, output_linear_index, memref);
  SmallVector<Value, 4> operand_values;
  // The last operand is the output buffer; only the inputs are loaded.
  for (Value operand_memref : op->getOperands().drop_back()) {
    Value operand_data = createLoadOrUseCachedValue(
        loc, &b, operand_memref, multidim_index, b.saveInsertionPoint());
    operand_values.push_back(operand_data);
  }
  auto res = HloOpToStdScalarOp::map<LHLO_OpTy>(
      llvm::cast<LHLO_OpTy>(op),
      result_memref.getType().cast<MemRefType>().getElementType(),
      operand_values, &b);
  createOffsetStore(b, loc, res, result_memref, output_linear_index);
  return success();
}

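// Lowers a non-elementwise op (e.g. a slice or broadcast) by delegating the
// per-element computation to elementalLower().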
template <typename LHLO_OpTy>
LogicalResult miscLowerHelper(
    OpBuilder& b, Location loc, Operation* opaque_op, Value output_linear_index,
    const ShapeConstraintAnalysis* shape_constraint_analysis) {
  LHLO_OpTy op = dyn_cast<LHLO_OpTy>(opaque_op);
  if (!op) return failure();
  Value result_memref = cast<LmhloOp>(&*op).getResultBuffer();
  Value memref = result_memref;
  if (shape_constraint_analysis) {
    Value leader_memref =
        shape_constraint_analysis->GetLeaderValueWithSameShape(result_memref);
    if (leader_memref != nullptr) {
      memref = leader_memref;
    }
  }
  llvm::SmallVector<Value, 4> output_multidim_index =
      calcMultiDimIndex(b, loc, output_linear_index, memref);
  Value operand_data = elementalLower(&b, loc, op, output_multidim_index,
                                      /*check_cache=*/true);
  createOffsetStore(b, loc, operand_data, result_memref, output_linear_index);
  return success();
}

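// Tries each op type in the template parameter pack in turn until one of the
// elementwise lowerings succeeds.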
template <typename First>
LogicalResult elemwiseLowerHelperOr(
    OpBuilder& b, Location loc, Operation* op, Value output_linear_index,
    const ShapeConstraintAnalysis* shape_constraint_analysis) {
  return elemwiseLowerHelper<First>(b, loc, op, output_linear_index,
                                    shape_constraint_analysis);
}

template <typename First, typename Second, typename... Rest>
LogicalResult elemwiseLowerHelperOr(
    OpBuilder& b, Location loc, Operation* op, Value output_linear_index,
    const ShapeConstraintAnalysis* shape_constraint_analysis) {
  return success(
      succeeded(elemwiseLowerHelperOr<First>(b, loc, op, output_linear_index,
                                             shape_constraint_analysis)) ||
      succeeded(elemwiseLowerHelperOr<Second, Rest...>(
          b, loc, op, output_linear_index, shape_constraint_analysis)));
}

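// Dispatches an op to the matching lowering: first the supported elementwise
// op list, then the handful of misc ops handled by miscLowerHelper.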
LogicalResult lowerHelper(
    OpBuilder& b, Location loc, Operation* op, Value output_linear_index,
    const ShapeConstraintAnalysis* shape_constraint_analysis) {
  if (succeeded(
          elemwiseLowerHelperOr<
#define GET_SUPPORTED_OP_LIST
#include "mlir-hlo/utils/disc_supported_list.h.inc"
              >(b, loc, op, output_linear_index, shape_constraint_analysis)) ||
      // TODO(disc): Upstream is on the way for more Ops
      succeeded(miscLowerHelper<RealDynamicSliceOp>(
          b, loc, op, output_linear_index, shape_constraint_analysis)) ||
      succeeded(miscLowerHelper<DynamicBroadcastInDimOp>(
          b, loc, op, output_linear_index, shape_constraint_analysis)) ||
      succeeded(miscLowerHelper<BroadcastInDimOp>(
          b, loc, op, output_linear_index, shape_constraint_analysis))) {
    return success();
  }
  return failure();
}

// No in-bounds check is emitted for the kLoop schedule; the LoopSplit pass
// takes care of that.
//
// The emitted loop has the following shape (pseudo IR):
/* %num_elements = ElementsIn(root_shape)
 * scf.for %idx = 0 to %num_elements step 1 {
 *   %multidim_indices_0..n = getMultidimIndices(%idx);
 *   %operand_0 = load %operand0[%multidim_indices_0..n]
 *   %operand_1 = load %operand1[%multidim_indices_0..n]
 *   emit calculation..
 * }
 */
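// 'root_ops' are the ops to emit inside the loop; 'dominant_op' defines the
// iteration space; 'parent' must be the fused block when 'non_fusion' is
// false, so that dead lhlo ops can be cleaned up afterwards.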
LogicalResult lowerWithScheduleLoop(
    ArrayRef<Operation*> root_ops, Operation* dominant_op,
    Block* parent = nullptr, bool non_fusion = false, bool parallel_loop = true,
    const ShapeConstraintAnalysis* shape_constraint_analysis = nullptr) {
  const auto loc = dominant_op->getLoc();
  OpBuilder b(root_ops.back());
  auto zero = b.create<ConstantOp>(loc, b.getIndexType(),
                                   b.getIntegerAttr(b.getIndexType(), 0));
  auto one = b.create<ConstantOp>(loc, b.getIndexType(),
                                  b.getIntegerAttr(b.getIndexType(), 1));
  auto num_elements =
      codegen_utils::emitNumElementsComputation(b, loc, dominant_op);
  Value var;
  if (parallel_loop) {
    SmallVector<Value, 2> vars;
    (void)createParallelAndSetInsPt(b, loc, vars, {zero}, {num_elements}, {one},
                                    {});
    var = vars[0];
  } else {
    (void)createLoopAndSetInsPt(b, loc, var, zero, num_elements, one, {});
  }
  for (Operation* root_op : root_ops) {
    if (failed(lowerHelper(b, loc, root_op, var, shape_constraint_analysis)))
      return failure();
  }
  // In the non-fusion case the root ops are now dead, so erase them; in the
  // fusion case, clean up any lhlo ops in the fused block that no longer
  // have users.
  if (non_fusion) {
    for (Operation* root_op : root_ops) root_op->erase();
  } else {
    assert(parent != nullptr && "Parent must be provided for fusion lowering");
    cleanUnusedLhloOps(parent);
  }
  return success();
}

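// Returns true if the op produces its result in a GPU buffer, judging by the
// memory-space attribute of the result memref.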
bool isOnGpu(Operation* op) {
  if (isa<FusionOp>(op))
    // TODO(disc): Revisit this when fusion on cpu is supported
    return true;
  assert(isa<LmhloOp>(op) && "Unexpected usage of isOnGpu");
  auto result_memref = cast<LmhloOp>(op).getResultBuffer();
  auto memory_space =
      result_memref.getType().cast<MemRefType>().getMemorySpace();
  return memory_space && memory_space.isa<StringAttr>() &&
         memory_space.cast<StringAttr>().getValue() ==
             mhlo::placement_utils::c_gpu;
}

}  // namespace

// Expands the root ops in a fused func into a parallel loop or a set of
// nested loops. This pass must be executed after the fusion pass, and works
// together with the InputInlineFusion pass that follows it for fusion codegen.
//
// TODO(disc): Currently this pass requires lmhlo.FusionOp to contain lmhlo
// ops, not mhlo, mainly because fusion is currently done on lmhlo rather
// than mhlo. The fusion pass can be moved to mhlo once the shape dialect is
// brought in to represent shape calculations on the tensor level, which
// would enable shape-calculation lowering for mhlo.FusionOp. Reconsider the
// fusion representation after these are done; an lmhlo.FusionOp with mhlo
// inside would be friendlier to the legacy FusedIrEmitter.
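//
// As an illustration (operands and types simplified, not taken from a real
// test case), a kLoop fusion such as:
//
//   "lmhlo.fusion"() ({
//     "lmhlo.add"(%lhs, %rhs, %out)
//         : (memref<?xf32>, memref<?xf32>, memref<?xf32>) -> ()
//     "lmhlo.terminator"() : () -> ()
//   }) : () -> ()
//
// is roughly expanded into:
//
//   scf.parallel (%idx) = (%c0) to (%num_elements) step (%c1) {
//     %a = memref.load %lhs[%idx] : memref<?xf32>
//     %b = memref.load %rhs[%idx] : memref<?xf32>
//     %sum = addf %a, %b : f32
//     memref.store %sum, %out[%idx] : memref<?xf32>
//     scf.yield
//   }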
class LhloLegalizeRootsToParallelLoops
    : public LhloLegalizeRootsToParallelLoopsPassBase<
          LhloLegalizeRootsToParallelLoops> {
  void runOnFunction() override {
    auto func = getFunction();
    OpBuilder b(func);
    SmallVector<Operation*, 4> gpu_non_fusion_worklist;
    SmallVector<Operation*, 4> cpu_non_fusion_worklist;
    SmallVector<Operation*, 4> gpu_fusion_worklist;
    for (mlir::Operation& op : func.body().getOps()) {
      if (isa<FusionOp>(&op)) {
        // TODO(disc): Revisit this when fusion on cpu is supported
        gpu_fusion_worklist.push_back(&op);
      } else if (isa<LmhloOp>(&op)) {
        if (isOnGpu(&op))
          gpu_non_fusion_worklist.push_back(&op);
        else
          cpu_non_fusion_worklist.push_back(&op);
      }
    }

    for (Operation* op : cpu_non_fusion_worklist) {
      // When the backend is GPU, ops placed on the CPU only calculate shapes,
      // so a simple serial schedule is sufficient for performance.
      // TODO(disc): Revisit this when the backend is cpu and the calculation
      // is for data.
      if (failed(lowerWithScheduleLoop({op}, op, nullptr,
                                       /*non_fusion=*/true,
                                       /*parallel_loop=*/false))) {
        op->emitError() << "failed to lower to loops";
        signalPassFailure();
        return;
      }
    }

    for (Operation* op : gpu_non_fusion_worklist) {
      // TODO(disc): single nodes with a non-kLoop schedule, such as ReduceOp,
      // are not implemented yet. Currently ReduceOp is lowered with the loop
      // schedule, which results in poor performance.
      if (failed(lowerWithScheduleLoop({op}, op, nullptr,
                                       /*non_fusion=*/true,
                                       /*parallel_loop=*/true))) {
        op->emitError() << "failed to lower to loops";
        signalPassFailure();
        return;
      }
    }

    for (Operation* fusion : gpu_fusion_worklist) {
      auto fusion_op = cast<FusionOp>(fusion);
      FusionPattern fusion_pattern(fusion_op);
      auto root_ops = fusion_pattern.getRootOps();
      auto fused_block = &(fusion_op.region().front());
      SmallVector<Operation*, 4> op_list;
      fused_block->walk(
          [&](LmhloOp op) { op_list.push_back(op.getOperation()); });
      ShapeConstraintAnalysis shape_constraint_analysis(op_list);

      // Nothing to codegen for this fusion; continue with the next one.
      if (root_ops.empty()) {
        continue;
      }
      // Emit a loop that writes the init value into the buffer for each
      // ColReduction root. This will be further lowered to an init kernel.
      // TODO(disc): Code upstream is on the way
      // maybeEmitInitLoops(b, root_ops);

      // 1. If there is any reduce op among the 'root_ops', follow its
      //    schedule; otherwise, follow the kLoop schedule.
      // 2. If column reductions and row reductions are mixed, follow the
      //    schedule of the row reduction, and implement all the column
      //    reductions in the 'pure atomic' way, which places no requirement
      //    on the schedule.
      // TODO(disc): the support of row reduction and 'pure atomic' reduction
      auto fusion_type = fusion_pattern.getFusionType();
      auto dominant_op = fusion_pattern.getDominantOp();
      switch (fusion_type) {
        case FusionType::kRowReduction:
          dominant_op->emitError() << "Unsupported kRowReduction Schedule";
          signalPassFailure();
          return;

        case FusionType::kColReduction:
          dominant_op->emitError() << "Unsupported kColReduction Schedule";
          signalPassFailure();
          return;

        case FusionType::kLoop:
          if (failed(lowerWithScheduleLoop(root_ops, dominant_op, fused_block,
                                           /*non_fusion=*/false,
                                           /*parallel_loop=*/true,
                                           &shape_constraint_analysis))) {
            dominant_op->emitError() << "failed to lower to loops";
            signalPassFailure();
            return;
          }
          break;
        default:
          dominant_op->emitError() << "Unknown fusion type";
          signalPassFailure();
          return;
      }
    }
  }
};

std::unique_ptr<OperationPass<FuncOp>>
createLhloLegalizeRootsToParallelLoopsPass() {
  return std::make_unique<LhloLegalizeRootsToParallelLoops>();
}
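
// A sketch of wiring the pass into a pipeline (the surrounding PassManager
// setup is illustrative, not part of this file):
//   pm.addNestedPass<FuncOp>(createLhloLegalizeRootsToParallelLoopsPass());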

}  // namespace lmhlo
}  // namespace mlir