/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h"

#include "absl/memory/memory.h"
#include "llvm/ADT/SetVector.h"
#include "mlir/Dialect/GPU/GPUDialect.h"  // from @llvm-project
#include "mlir/Dialect/GPU/ParallelLoopMapper.h"  // from @llvm-project
#include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
#include "mlir/Dialect/SCF/Transforms.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
#include "mlir/Transforms/LoopUtils.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"

namespace xla {
namespace mlir_gpu {
namespace {

#define GEN_PASS_CLASSES
#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h.inc"

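// Removes lmhlo.fusion ops by cloning the operations from their region into
// the surrounding block and erasing the fusion op.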
struct FusionOpRemoverPass : FusionOpRemoverPassBase<FusionOpRemoverPass> {
  void runOnFunction() override {
    getFunction().walk([&](mlir::lmhlo::FusionOp op) {
      mlir::OpBuilder builder(op);
      // FusionOp has a single region with a single block, so we can just walk
      // over it and clone operations to the outside.
      mlir::BlockAndValueMapping mapping;
      for (auto& nested_op : op.region().front().without_terminator()) {
        auto clone = builder.clone(nested_op, mapping);
        for (auto pair :
             llvm::zip(nested_op.getResults(), clone->getResults())) {
          mapping.map(std::get<0>(pair), std::get<1>(pair));
        }
      }
      op.erase();
    });
  }
};

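// Returns true if `op` reports a memory effect of type `EffectTy` via the
// MemoryEffectOpInterface; returns false for ops without that interface.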
template <typename EffectTy>
bool HasEffectsOnValue(mlir::Value value, mlir::Operation* op) {
  auto mem_effects_interface =
      mlir::dyn_cast_or_null<mlir::MemoryEffectOpInterface>(op);
  if (!mem_effects_interface) {
    return false;
  }
  llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 2> effects;
  mem_effects_interface.getEffects(effects);
  return llvm::any_of(effects,
                      [op](const mlir::MemoryEffects::EffectInstance& effect) {
                        return mlir::isa<EffectTy>(effect.getEffect());
                      });
}

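// Forwards stored values to loads: a load is replaced by the value of a
// preceding store to the same underlying allocation with identical indices.
// The search for a matching store walks backwards through the current block
// and continues in enclosing scf.parallel loops.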
struct StoreForwardingPass : StoreForwardingPassBase<StoreForwardingPass> {
  mlir::StoreOp findStore(mlir::Operation* op,
                          std::function<bool(mlir::StoreOp)> matches) {
    // Search from op upwards in the current block.
    mlir::Block* block = op->getBlock();
    auto startFromIt =
        std::find_if(block->rbegin(), block->rend(),
                     [op](mlir::Operation& other) { return &other == op; });
    for (auto storeOpIt = startFromIt; storeOpIt != block->rend();
         ++storeOpIt) {
      auto storeOp = llvm::dyn_cast<mlir::StoreOp>(&*(storeOpIt));
      if (!storeOp || !matches(storeOp)) {
        continue;
      }

      return storeOp;
    }
    // No store operation found. Continue the search outside of the parallel
    // loop if the block is in a parallel loop.
    if (auto parallelOp =
            llvm::dyn_cast<mlir::scf::ParallelOp>(block->getParentOp())) {
      return findStore(parallelOp.getOperation(), matches);
    }
    return {};
  }

  // Recursively searches the defining ops of `memref`, looking through
  // subviews, for the op that allocates it. Returns the allocating op if
  // found, otherwise nullptr.
  mlir::Operation* SearchAllocOp(mlir::Value memref) {
    mlir::Operation* defOp = memref.getDefiningOp();
    while (auto subviewOp = mlir::dyn_cast_or_null<mlir::SubViewOp>(defOp)) {
      defOp = subviewOp.source().getDefiningOp();
    }
    return HasEffectsOnValue<mlir::MemoryEffects::Allocate>(memref, defOp)
               ? defOp
               : nullptr;
  }

  // Retrieves the allocating op from the cache or actually looks for it.
  mlir::Operation* GetAllocOp(
      mlir::Value memref,
      llvm::DenseMap<mlir::Value, mlir::Operation*>* memrefToAllocOp) {
    auto allocOpIt = memrefToAllocOp->find(memref);
    if (allocOpIt != memrefToAllocOp->end()) {
      return allocOpIt->second;
    }
    mlir::Operation* allocOp = SearchAllocOp(memref);
    memrefToAllocOp->insert({memref, allocOp});
    return allocOp;
  }

  void runOnFunction() override {
    llvm::DenseMap<mlir::Value, mlir::Operation*> memrefToAllocOp;

    getFunction().walk([&](mlir::LoadOp loadOp) {
      auto storeOp = findStore(loadOp, [&](mlir::StoreOp storeOp) {
        mlir::Operation* storeOpAlloc =
            GetAllocOp(storeOp.memref(), &memrefToAllocOp);
        mlir::Operation* loadOpAlloc =
            GetAllocOp(loadOp.memref(), &memrefToAllocOp);
        return storeOpAlloc && loadOpAlloc && (storeOpAlloc == loadOpAlloc);
      });
      if (!storeOp) {
        return;
      }
      auto storeIndices = storeOp.getIndices();
      auto loadIndices = loadOp.getIndices();
      if (!std::equal(storeIndices.begin(), storeIndices.end(),
                      loadIndices.begin(), loadIndices.end())) {
        return;
      }
      loadOp.replaceAllUsesWith(storeOp.getValueToStore());
      loadOp.erase();
    });
  }
};

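// Removes allocations of temporary buffers whose contents are never read: the
// allocation and its transitive users (stores, deallocs, dead subviews, and
// loads without uses) are erased.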
struct DeadTempBufferRemovalPass
    : DeadTempBufferRemovalPassBase<DeadTempBufferRemovalPass> {
  bool operationConsideredDead(mlir::Operation* op) {
    for (auto result : op->getResults()) {
      if (!llvm::all_of(result.getUsers(), [&](mlir::Operation* op) {
            // Store and Dealloc are OK.
            if (llvm::isa<mlir::StoreOp, mlir::DeallocOp>(op)) {
              return true;
            }
            // A load without uses is also OK.
            if (auto loadOp = llvm::dyn_cast<mlir::LoadOp>(op)) {
              return loadOp.use_empty();
            }
            // A subview is OK if it is dead itself.
            if (llvm::isa<mlir::SubViewOp>(op)) {
              return operationConsideredDead(op);
            }
            return false;
          })) {
        return false;
      }
    }
    return true;
  }

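  // Collects `op` and all its transitive users into `erase_list`, users
  // before definers, so the list can be erased front to back.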
  void recursiveErase(mlir::Operation* op,
                      llvm::SmallVectorImpl<mlir::Operation*>* erase_list) {
    for (auto result : op->getResults()) {
      for (auto user : llvm::make_early_inc_range(result.getUsers())) {
        recursiveErase(user, erase_list);
      }
    }
    erase_list->push_back(op);
  }

  void runOnFunction() override {
    llvm::SmallVector<mlir::Operation*, 8> dead_ops;
    getFunction().walk([&](mlir::Operation* op) {
      if (op->getNumResults() != 1 ||
          !HasEffectsOnValue<mlir::MemoryEffects::Allocate>(op->getResult(0),
                                                            op)) {
        return;
      }
      if (!operationConsideredDead(op)) {
        return;
      }

      // TODO(herhut): There should be a generic helper for this.
      recursiveErase(op, &dead_ops);
    });
    for (auto op : dead_ops) {
      op->erase();
    }
  }
};

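// Rewrites each gpu.func kernel so that its arguments are exactly the
// arguments and results of the surrounding function, in that order, and
// replaces the corresponding gpu.launch_func accordingly.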
struct RewriteKernelSignaturePass
    : RewriteKernelSignaturePassBase<RewriteKernelSignaturePass> {
  void runOnFunction() override {
    mlir::FuncOp func = getFunction();
    mlir::ModuleOp module = func->getParentOfType<mlir::ModuleOp>();
    getFunction().walk([&](mlir::gpu::LaunchFuncOp launchOp) {
      mlir::gpu::GPUFuncOp kernel =
          module.lookupSymbol<mlir::gpu::GPUFuncOp>(launchOp.kernel());

      if (kernel.getNumFuncArguments() !=
          func.getNumArguments() + func.getNumResults()) {
        kernel.emitError()
            << "number of kernel arguments does not match number "
            << "of arguments and results of surrounding function";
        signalPassFailure();
        return;
      }
      if (!llvm::hasSingleElement(func)) {
        func.emitError() << "surrounding function has more than one block";
        signalPassFailure();
        return;
      }

      // Compute a map from function arguments to kernel function operands.
      mlir::BlockAndValueMapping func_to_kernel;
      for (mlir::BlockArgument arg : func.getArguments()) {
        for (int i = 0, e = launchOp.getNumKernelOperands(); i < e; ++i) {
          if (launchOp.getKernelOperand(i) == arg) {
            func_to_kernel.map(arg, kernel.getArgument(i));
            break;
          }
        }
      }
      // Also add function results that are computed by the launch.
      mlir::Operation* returnOp = func.getBody().back().getTerminator();
      for (mlir::Value result : returnOp->getOperands()) {
        for (int i = 0, e = launchOp.getNumKernelOperands(); i < e; ++i) {
          if (launchOp.getKernelOperand(i) == result) {
            func_to_kernel.map(result, kernel.getArgument(i));
            break;
          }
        }
      }

      // Create a new kernel function with a modified signature. Its parameter
      // types are the parameter and result types of the original function,
      // and it returns nothing.
      auto gpu_module = kernel->getParentOfType<mlir::gpu::GPUModuleOp>();
      mlir::OpBuilder kernel_builder(gpu_module.body());
      auto operand_types = llvm::to_vector<4>(llvm::concat<const mlir::Type>(
          func.getType().getInputs(), func.getType().getResults()));
      auto new_kernel = kernel_builder.create<mlir::gpu::GPUFuncOp>(
          kernel.getLoc(), kernel.getName(),
          kernel_builder.getFunctionType(operand_types, {}));
      new_kernel->setAttr(mlir::gpu::GPUDialect::getKernelFuncAttrName(),
                          kernel_builder.getUnitAttr());

      // Create a map from old kernel argument to new one.
      mlir::BlockAndValueMapping old_kernel_to_new;
      for (int i = 0, e = func.getNumArguments(); i < e; ++i) {
        mlir::Value func_arg = func.getArgument(i);
        mlir::Value new_kernel_arg = new_kernel.getArgument(i);
        mlir::Value old_kernel_arg = func_to_kernel.lookupOrNull(func_arg);
        if (!old_kernel_arg) {
          kernel.emitOpError()
              << "argument " << i
              << " to containing function is not an argument to the kernel";
          signalPassFailure();
          return;
        }
        old_kernel_to_new.map(old_kernel_arg, new_kernel_arg);
      }
      for (int i = 0, e = returnOp->getNumOperands(); i < e; ++i) {
        mlir::Value ret_op = returnOp->getOperand(i);
        mlir::Value new_kernel_arg =
            new_kernel.getArgument(func.getNumArguments() + i);
        mlir::Value old_kernel_arg = func_to_kernel.lookupOrNull(ret_op);
        if (!old_kernel_arg) {
          kernel.emitOpError()
              << "result " << i
              << " of containing function is not an argument to the kernel";
          signalPassFailure();
          return;
        }
        old_kernel_to_new.map(old_kernel_arg, new_kernel_arg);
      }
      // Steal the body by cloning the old kernel's blocks into the new kernel
      // and inserting a branch to them.
      kernel.body().cloneInto(&new_kernel.getBody(), old_kernel_to_new);
      kernel_builder.setInsertionPointToEnd(&new_kernel.body().front());
      kernel_builder.create<mlir::BranchOp>(
          new_kernel.getLoc(), &*std::next(new_kernel.body().begin()));
      // Now create a new launchOp calling the new kernel. We need to forward
      // the arguments of the surrounding function and the operands of its
      // return.
      mlir::SmallVector<mlir::Value, 4> new_operands;
      new_operands.reserve(new_kernel.getNumFuncArguments());
      new_operands.append(func.args_begin(), func.args_end());
      new_operands.append(returnOp->operand_begin(), returnOp->operand_end());
      mlir::OpBuilder launch_builder(launchOp);
      launch_builder.create<mlir::gpu::LaunchFuncOp>(
          launchOp.getLoc(), new_kernel, launchOp.getGridSizeOperandValues(),
          launchOp.getBlockSizeOperandValues(), new_operands);
      // The launch has no results, so we can just erase it. The old kernel
      // needs to go as well.
      launchOp.erase();
      kernel.erase();
    });
  }
};

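// Greedily maps all scf.parallel loops in the function to GPU hardware
// dimensions (blocks and threads).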
struct MapParallelLoopsPass : MapParallelLoopsPassBase<MapParallelLoopsPass> {
  void runOnFunction() override {
    mlir::greedilyMapParallelSCFToGPU(getFunction().getBody());
  }
};

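// For each scf.parallel loop, naively fuses the parallel loops nested
// directly inside its region.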
struct FuseInnerParallelLoopsPass
    : FuseInnerParallelLoopsPassBase<FuseInnerParallelLoopsPass> {
  void runOnFunction() override {
    getFunction().walk([](mlir::scf::ParallelOp op) {
      mlir::scf::naivelyFuseParallelOps(op.region());
    });
  }
};

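// Collapses all induction variables of each scf.parallel loop into a single
// loop dimension.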
struct ParallelLoopCollapsingToFirstDimPass
    : ParallelLoopCollapsingToFirstDimPassBase<
          ParallelLoopCollapsingToFirstDimPass> {
  void runOnFunction() override {
    getFunction().walk([&](mlir::scf::ParallelOp op) {
      unsigned num_loops = op.getNumLoops();
      std::vector<unsigned> combinedLoops;
      combinedLoops.reserve(num_loops);
      for (unsigned i = 0; i < num_loops; ++i) {
        combinedLoops.push_back(i);
      }
      mlir::collapseParallelLoops(op, {combinedLoops});
    });
  }
};

}  // namespace

std::unique_ptr<mlir::FunctionPass> createFusionOpRemoverPass() {
  return absl::make_unique<FusionOpRemoverPass>();
}

std::unique_ptr<mlir::FunctionPass> createStoreForwardingPass() {
  return absl::make_unique<StoreForwardingPass>();
}

std::unique_ptr<mlir::FunctionPass> createDeadTempBufferRemovalPass() {
  return absl::make_unique<DeadTempBufferRemovalPass>();
}

std::unique_ptr<mlir::FunctionPass> createRewriteKernelSignaturePass() {
  return absl::make_unique<RewriteKernelSignaturePass>();
}

std::unique_ptr<mlir::FunctionPass> createFuseInnerParallelLoopsPass() {
  return absl::make_unique<FuseInnerParallelLoopsPass>();
}

std::unique_ptr<mlir::FunctionPass> createMapParallelLoopsPass() {
  return absl::make_unique<MapParallelLoopsPass>();
}

std::unique_ptr<mlir::FunctionPass>
createParallelLoopCollapsingToFirstDimPass() {
  return absl::make_unique<ParallelLoopCollapsingToFirstDimPass>();
}

}  // namespace mlir_gpu
}  // namespace xla