/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements the logic for lowering the root ops of LHLO fusions
// (and non-fused LHLO ops) to parallel loops or nested loops.
#include "llvm/Support/Debug.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/fusion_utils.h"
#include "mlir-hlo/Dialect/mhlo/transforms/lhlo_elemental_utils.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir-hlo/utils/codegen_utils.h"
#include "mlir-hlo/utils/placement_utils.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"

using mlir::codegen_utils::calcMultiDimIndex;

namespace mlir {
namespace lmhlo {

#define GEN_PASS_CLASSES
#include "mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.h.inc"

namespace {

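// Lowers one output element of a supported elementwise lmhlo op: delinearizes
// `output_linear_index` against the result buffer (or the leader buffer with
// the same shape, if a ShapeConstraintAnalysis is provided), loads the
// operands at that multidim index, maps the op to its scalar counterpart, and
// stores the result back at the linear offset of the result buffer.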
template <typename LHLO_OpTy>
LogicalResult elemwiseLowerHelper(
    OpBuilder& b, Location loc, Operation* op, Value output_linear_index,
    const ShapeConstraintAnalysis* shape_constraint_analysis) {
  if (!isa<LHLO_OpTy>(op) || !op->hasTrait<mlir::OpTrait::Elementwise>())
    return failure();

  Value result_memref = cast<LmhloOp>(op).getResultBuffer();
  Value memref = result_memref;
  if (shape_constraint_analysis) {
    Value leader_memref =
        shape_constraint_analysis->GetLeaderValueWithSameShape(result_memref);
    if (leader_memref != nullptr) memref = leader_memref;
  }
  // TODO(disc): Replace with memref.Delinearize
  auto multidim_index = calcMultiDimIndex(b, loc, output_linear_index, memref);
  SmallVector<Value, 4> operand_values;
  for (Value operand_memref : op->getOperands().drop_back()) {
    Value operand_data = createLoadOrUseCachedValue(
        loc, &b, operand_memref, multidim_index, b.saveInsertionPoint());
    operand_values.push_back(operand_data);
  }
  auto res = HloOpToStdScalarOp::map<LHLO_OpTy>(
      llvm::cast<LHLO_OpTy>(op),
      result_memref.getType().cast<MemRefType>().getElementType(),
      operand_values, &b);
  createOffsetStore(b, loc, res, result_memref, output_linear_index);
  return success();
}

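// Lowers one output element of a non-elementwise lmhlo op (e.g. a slice or a
// broadcast): delinearizes `output_linear_index`, emits the op-specific
// elemental computation via elementalLower(), and stores the result at the
// linear offset of the result buffer.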
template <typename LHLO_OpTy>
LogicalResult miscLowerHelper(
    OpBuilder& b, Location loc, Operation* opaque_op, Value output_linear_index,
    const ShapeConstraintAnalysis* shape_constraint_analysis) {
  LHLO_OpTy op = dyn_cast<LHLO_OpTy>(opaque_op);
  if (!op) return failure();
  Value result_memref = cast<LmhloOp>(&*op).getResultBuffer();
  Value memref = result_memref;
  if (shape_constraint_analysis) {
    Value leader_memref =
        shape_constraint_analysis->GetLeaderValueWithSameShape(result_memref);
    if (leader_memref != nullptr) {
      memref = leader_memref;
    }
  }
  llvm::SmallVector<Value, 4> output_multidim_index =
      calcMultiDimIndex(b, loc, output_linear_index, memref);
  Value operand_data = elementalLower(&b, loc, op, output_multidim_index,
                                      /*check_cache=*/true);
  createOffsetStore(b, loc, operand_data, result_memref, output_linear_index);
  return success();
}

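// Tries elemwiseLowerHelper for each op type in the template parameter pack
// and succeeds if any one of them matches `op`.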
template <typename First>
LogicalResult elemwiseLowerHelperOr(
    OpBuilder& b, Location loc, Operation* op, Value output_linear_index,
    const ShapeConstraintAnalysis* shape_constraint_analysis) {
  return elemwiseLowerHelper<First>(b, loc, op, output_linear_index,
                                    shape_constraint_analysis);
}

template <typename First, typename Second, typename... Rest>
LogicalResult elemwiseLowerHelperOr(
    OpBuilder& b, Location loc, Operation* op, Value output_linear_index,
    const ShapeConstraintAnalysis* shape_constraint_analysis) {
  return success(
      succeeded(elemwiseLowerHelperOr<First>(b, loc, op, output_linear_index,
                                             shape_constraint_analysis)) ||
      succeeded(elemwiseLowerHelperOr<Second, Rest...>(
          b, loc, op, output_linear_index, shape_constraint_analysis)));
}

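// Lowers one output element of `op`, dispatching to the elementwise helper
// for the op types listed in disc_supported_list.h.inc and to miscLowerHelper
// for the explicitly handled non-elementwise ops below.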
LogicalResult lowerHelper(
    OpBuilder& b, Location loc, Operation* op, Value output_linear_index,
    const ShapeConstraintAnalysis* shape_constraint_analysis) {
  if (succeeded(
          elemwiseLowerHelperOr<
#define GET_SUPPORTED_OP_LIST
#include "mlir-hlo/utils/disc_supported_list.h.inc"
              >(b, loc, op, output_linear_index, shape_constraint_analysis)) ||
      // TODO(disc): Upstream is on the way for more Ops
      succeeded(miscLowerHelper<RealDynamicSliceOp>(
          b, loc, op, output_linear_index, shape_constraint_analysis)) ||
      succeeded(miscLowerHelper<DynamicBroadcastInDimOp>(
          b, loc, op, output_linear_index, shape_constraint_analysis)) ||
      succeeded(miscLowerHelper<BroadcastInDimOp>(
          b, loc, op, output_linear_index, shape_constraint_analysis))) {
    return success();
  }
  return failure();
}

// We don't do an in-bounds check for the kLoop schedule;
// the LoopSplit pass will do this.
//
/* %num_elements = ElementsIn(root_shape)
 * scf.for %idx = 0 to %num_elements step 1 {
 *   %multidim_indices_0..n = getMultidimIndices(%idx);
 *   %operand_0 = load %operand0[]
 *   %operand_1 = load %operand1[]
 *   emit calculation..
 * }
 */
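// Emits a single loop over the element count of `dominant_op` and lowers every
// op in `root_ops` at the loop body's insertion point. `parallel_loop` selects
// between a parallel loop and a sequential one. For non-fusion ops the root
// ops are erased afterwards; for fusions, the now-unused lmhlo ops inside
// `parent` are cleaned up instead.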
LogicalResult lowerWithScheduleLoop(
    ArrayRef<Operation*> root_ops, Operation* dominant_op,
    Block* parent = nullptr, bool non_fusion = false, bool parallel_loop = true,
    const ShapeConstraintAnalysis* shape_constraint_analysis = nullptr) {
  const auto loc = dominant_op->getLoc();
  OpBuilder b(root_ops.back());
  auto zero = b.create<ConstantOp>(loc, b.getIndexType(),
                                   b.getIntegerAttr(b.getIndexType(), 0));
  auto one = b.create<ConstantOp>(loc, b.getIndexType(),
                                  b.getIntegerAttr(b.getIndexType(), 1));
  auto num_elements =
      codegen_utils::emitNumElementsComputation(b, loc, dominant_op);
  Value var;
  if (parallel_loop) {
    SmallVector<Value, 2> vars;
    (void)createParallelAndSetInsPt(b, loc, vars, {zero}, {num_elements}, {one},
                                    {});
    var = vars[0];
  } else {
    (void)createLoopAndSetInsPt(b, loc, var, zero, num_elements, one, {});
  }
  for (Operation* root_op : root_ops) {
    if (failed(lowerHelper(b, loc, root_op, var, shape_constraint_analysis)))
      return failure();
  }
  // Remove the root ops if they have no users other than the memref.
  if (non_fusion) {
    for (Operation* root_op : root_ops) root_op->erase();
  } else {
    assert(parent != nullptr && "Parent must be provided for fusion lowering");
    cleanUnusedLhloOps(parent);
  }
  return success();
}

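// Returns true if `op` writes to a buffer placed in GPU memory. FusionOps are
// currently always treated as being on GPU.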
bool isOnGpu(Operation* op) {
  if (isa<FusionOp>(op))
    // TODO(disc): Revisit this when fusion on cpu is supported
    return true;
  assert(isa<LmhloOp>(op) && "Unexpected usage of isOnGpu");
  auto result_memref = cast<LmhloOp>(op).getResultBuffer();
  auto memory_space =
      result_memref.getType().cast<MemRefType>().getMemorySpace();
  return memory_space && memory_space.isa<StringAttr>() &&
         memory_space.cast<StringAttr>().getValue() ==
             mhlo::placement_utils::c_gpu;
}

}  // namespace

// Expand the root ops in a fused func into a parallel loop or a set of
// nested loops. This pass must be executed after the fusion pass, and works
// together with the InputInlineFusion pass that follows it for fusion codegen.
//
// TODO(disc): Currently this pass expects lmhlo.FusionOp to contain lmhlo ops,
// not mhlo ops, mainly because we currently do fusion on lmhlo rather than
// mhlo. The fusion pass can be moved to mhlo once the shape dialect is used to
// represent shape calculations on the tensor level and we can lower shape
// calculations for mhlo.FusionOp. Reconsider the fusion representation once
// these are done: an lmhlo.FusionOp with mhlo ops inside would be more
// friendly to the legacy FusedIrEmitter.
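//
// For a simple elementwise fusion the emitted code is roughly of the form
// below (an illustrative sketch only; the exact IR depends on the ops involved
// and on the index/caching helpers):
//
//   %num_elements = ... // product of the dominant op's result dims
//   scf.parallel (%idx) = (%c0) to (%num_elements) step (%c1) {
//     %indices = ... // delinearized multidim index for %idx
//     %lhs = memref.load %operand0[%indices]
//     %rhs = memref.load %operand1[%indices]
//     %res = ... // scalar computation mapped from the lmhlo op
//     // store %res into the result buffer at linear offset %idx
//   }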
class LhloLegalizeRootsToParallelLoops
    : public LhloLegalizeRootsToParallelLoopsPassBase<
          LhloLegalizeRootsToParallelLoops> {
  void runOnFunction() override {
    auto func = getFunction();
    OpBuilder b(func);
    SmallVector<Operation*, 4> gpu_non_fusion_worklist;
    SmallVector<Operation*, 4> cpu_non_fusion_worklist;
    SmallVector<Operation*, 4> gpu_fusion_worklist;
    for (mlir::Operation& op : func.body().getOps()) {
      if (isa<FusionOp>(&op)) {
        // TODO(disc): Revisit this when fusion on cpu is supported
        gpu_fusion_worklist.push_back(&op);
      } else if (isa<LmhloOp>(&op)) {
        if (isOnGpu(&op))
          gpu_non_fusion_worklist.push_back(&op);
        else
          cpu_non_fusion_worklist.push_back(&op);
      }
    }

    for (Operation* op : cpu_non_fusion_worklist) {
      // These ops are only used for shape calculation when the backend is
      // gpu, so a simple schedule should be sufficient for performance.
      // TODO(disc): Revisit this when the backend is cpu and the calculation
      // is for data.
      if (failed(lowerWithScheduleLoop({op}, op, nullptr,
                                       /*non_fusion=*/true,
                                       /*parallel_loop=*/false))) {
        op->emitError() << "failed to lower to loops";
        signalPassFailure();
        return;
      }
    }

    for (Operation* op : gpu_non_fusion_worklist) {
      // TODO(disc): single nodes with a non-kLoop schedule, such as ReduceOp,
      // are not implemented yet. Currently ReduceOp is lowered with the loop
      // schedule, which means poor performance.
      if (failed(lowerWithScheduleLoop({op}, op, nullptr,
                                       /*non_fusion=*/true,
                                       /*parallel_loop=*/true))) {
        op->emitError() << "failed to lower to loops";
        signalPassFailure();
        return;
      }
    }

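    // For each fusion, run a ShapeConstraintAnalysis over all lmhlo ops in the
    // fused block so that ops whose buffers are known to share a shape use a
    // common leader buffer when delinearizing indices, then lower the root ops
    // according to the fusion's schedule.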
    for (Operation* fusion : gpu_fusion_worklist) {
      auto fusion_op = cast<FusionOp>(fusion);
      FusionPattern fusion_pattern(fusion_op);
      auto root_ops = fusion_pattern.getRootOps();
      auto fused_block = &(fusion_op.region().front());
      SmallVector<Operation*, 4> op_list;
      fused_block->walk(
          [&](LmhloOp op) { op_list.push_back(op.getOperation()); });
      ShapeConstraintAnalysis shape_constraint_analysis(op_list);

      // No need to do codegen, return directly.
      if (root_ops.empty()) {
        return;
      }
      // Make a loop that writes the init value into the buffer for each
      // ColReduction root. This will be further lowered to an init_kernel.
      // TODO(disc): Code upstream is on the way
      // maybeEmitInitLoops(b, root_ops);

      // 1. If there is any reduce op among the 'root_ops', follow its
      //    schedule; otherwise, follow the kLoop schedule.
      // 2. If there is a mix of column reductions and row reductions, follow
      //    the schedule of the row reduction, and implement all the column
      //    reductions in the 'pure atomic' way, which places no requirement
      //    on the schedule.
      // TODO(disc): support row reduction and the 'pure atomic' reduction
      auto fusion_type = fusion_pattern.getFusionType();
      auto dominant_op = fusion_pattern.getDominantOp();
      switch (fusion_type) {
        case FusionType::kRowReduction:
          dominant_op->emitError() << "Unsupported kRowReduction Schedule";
          signalPassFailure();
          return;

        case FusionType::kColReduction:
          dominant_op->emitError() << "Unsupported kColReduction Schedule";
          signalPassFailure();
          return;

        case FusionType::kLoop:
          if (failed(lowerWithScheduleLoop(root_ops, dominant_op, fused_block,
                                           /*non_fusion=*/false,
                                           /*parallel_loop=*/true,
                                           &shape_constraint_analysis))) {
            dominant_op->emitError() << "failed to lower to loops";
            signalPassFailure();
            return;
          }
          break;
        default:
          dominant_op->emitError() << "Unknown fusion type";
          signalPassFailure();
          return;
      }
    }
  }
};

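// Returns a pass that expands the root ops of lmhlo fusions (and non-fused
// lmhlo ops) into parallel loops.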
std::unique_ptr<OperationPass<FuncOp>>
createLhloLegalizeRootsToParallelLoopsPass() {
  return std::make_unique<LhloLegalizeRootsToParallelLoops>();
}

}  // namespace lmhlo
}  // namespace mlir