/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h"

#include "absl/memory/memory.h"
#include "llvm/ADT/SetVector.h"
#include "mlir/Dialect/GPU/GPUDialect.h"  // from @llvm-project
#include "mlir/Dialect/GPU/ParallelLoopMapper.h"  // from @llvm-project
#include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
#include "mlir/Dialect/SCF/Transforms.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
#include "mlir/Transforms/LoopUtils.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"

namespace xla {
namespace mlir_gpu {
namespace {

#define GEN_PASS_CLASSES
#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h.inc"

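// Removes lmhlo.fusion ops by cloning the operations of their single-block
// region into the enclosing block and then erasing the fusion op itself.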
struct FusionOpRemoverPass : FusionOpRemoverPassBase<FusionOpRemoverPass> {
  void runOnFunction() override {
    getFunction().walk([&](mlir::lmhlo::FusionOp op) {
      mlir::OpBuilder builder(op);
      // FusionOp has a single region with a single block, so we can just walk
      // over it and clone operations to the outside.
      mlir::BlockAndValueMapping mapping;
      for (auto& nested_op : op.region().front().without_terminator()) {
        auto clone = builder.clone(nested_op, mapping);
        for (auto pair :
             llvm::zip(nested_op.getResults(), clone->getResults())) {
          mapping.map(std::get<0>(pair), std::get<1>(pair));
        }
      }
      op.erase();
    });
  }
};

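// Returns true if `op` implements MemoryEffectOpInterface and reports a
// memory effect of type `EffectTy`. Note that only the effect type is
// checked; `value` itself is not inspected.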
template <typename EffectTy>
bool HasEffectsOnValue(mlir::Value value, mlir::Operation* op) {
  auto mem_effects_interface =
      mlir::dyn_cast_or_null<mlir::MemoryEffectOpInterface>(op);
  if (!mem_effects_interface) {
    return false;
  }
  llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 2> effects;
  mem_effects_interface.getEffects(effects);
  return llvm::any_of(effects,
                      [op](const mlir::MemoryEffects::EffectInstance& effect) {
                        return mlir::isa<EffectTy>(effect.getEffect());
                      });
}

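// Forwards stored values to subsequent loads: a load is replaced by the value
// of the nearest preceding store to the same underlying allocation, provided
// both access identical indices.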
struct StoreForwardingPass : StoreForwardingPassBase<StoreForwardingPass> {
  mlir::StoreOp findStore(mlir::Operation* op,
                          std::function<bool(mlir::StoreOp)> matches) {
    // Search from op upwards in the current block.
    mlir::Block* block = op->getBlock();
    auto startFromIt =
        std::find_if(block->rbegin(), block->rend(),
                     [op](mlir::Operation& other) { return &other == op; });
    for (auto storeOpIt = startFromIt; storeOpIt != block->rend();
         ++storeOpIt) {
      auto storeOp = llvm::dyn_cast<mlir::StoreOp>(&*(storeOpIt));
      if (!storeOp || !matches(storeOp)) {
        continue;
      }

      return storeOp;
    }
    // No store operation found. Continue search outside of the parallel
    // loop if block is in a parallel loop.
    if (auto parallelOp =
            llvm::dyn_cast<mlir::scf::ParallelOp>(block->getParentOp())) {
      return findStore(parallelOp.getOperation(), matches);
    }
    return {};
  }

  // Recursively searches the defining ops of `memref` for an AllocOp. Returns
  // the AllocOp if one is found, nullptr otherwise.
  mlir::Operation* SearchAllocOp(mlir::Value memref) {
    mlir::Operation* defOp = memref.getDefiningOp();
    while (auto subviewOp = mlir::dyn_cast_or_null<mlir::SubViewOp>(defOp)) {
      defOp = subviewOp.source().getDefiningOp();
    }
    return HasEffectsOnValue<mlir::MemoryEffects::Allocate>(memref, defOp)
               ? defOp
               : nullptr;
  }

  // Retrieves the AllocOp from the cache or actually looks for it.
  mlir::Operation* GetAllocOp(
      mlir::Value memref,
      llvm::DenseMap<mlir::Value, mlir::Operation*>* memrefToAllocOp) {
    auto allocOpIt = memrefToAllocOp->find(memref);
    if (allocOpIt != memrefToAllocOp->end()) {
      return allocOpIt->second;
    }
    mlir::Operation* allocOp = SearchAllocOp(memref);
    memrefToAllocOp->insert({memref, allocOp});
    return allocOp;
  }

  void runOnFunction() override {
    llvm::DenseMap<mlir::Value, mlir::Operation*> memrefToAllocOp;

    getFunction().walk([&](mlir::LoadOp loadOp) {
      auto storeOp = findStore(loadOp, [&](mlir::StoreOp storeOp) {
        mlir::Operation* storeOpAlloc =
            GetAllocOp(storeOp.memref(), &memrefToAllocOp);
        mlir::Operation* loadOpAlloc =
            GetAllocOp(loadOp.memref(), &memrefToAllocOp);
        return storeOpAlloc && loadOpAlloc && (storeOpAlloc == loadOpAlloc);
      });
      if (!storeOp) {
        return;
      }
      auto storeIndices = storeOp.getIndices();
      auto loadIndices = loadOp.getIndices();
      if (!std::equal(storeIndices.begin(), storeIndices.end(),
                      loadIndices.begin(), loadIndices.end())) {
        return;
      }
      loadOp.replaceAllUsesWith(storeOp.getValueToStore());
      loadOp.erase();
    });
  }
};

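// Removes temporary buffers that are never meaningfully read: an allocation
// is considered dead if all of its users are stores, deallocs, loads without
// uses, or subviews that are dead by the same criterion. Dead allocations and
// all their (transitive) users are erased.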
struct DeadTempBufferRemovalPass
    : DeadTempBufferRemovalPassBase<DeadTempBufferRemovalPass> {
  bool operationConsideredDead(mlir::Operation* op) {
    for (auto result : op->getResults()) {
      if (!llvm::all_of(result.getUsers(), [&](mlir::Operation* op) {
            // Store and Dealloc are OK.
            if (llvm::isa<mlir::StoreOp, mlir::DeallocOp>(op)) {
              return true;
            }
            // A Load without uses is also OK.
            if (auto loadOp = llvm::dyn_cast<mlir::LoadOp>(op)) {
              return loadOp.use_empty();
            }
            // A Subview is OK if it is dead itself.
            if (llvm::isa<mlir::SubViewOp>(op)) {
              return operationConsideredDead(op);
            }
            return false;
          })) {
        return false;
      }
    }
    return true;
  }

  void recursiveErase(mlir::Operation* op,
                      llvm::SmallVectorImpl<mlir::Operation*>* erase_list) {
    for (auto result : op->getResults()) {
      for (auto user : llvm::make_early_inc_range(result.getUsers())) {
        recursiveErase(user, erase_list);
      }
    }
    erase_list->push_back(op);
  }

  void runOnFunction() override {
    llvm::SmallVector<mlir::Operation*, 8> dead_ops;
    getFunction().walk([&](mlir::Operation* op) {
      if (op->getNumResults() != 1 ||
          !HasEffectsOnValue<mlir::MemoryEffects::Allocate>(op->getResult(0),
                                                            op)) {
        return;
      }
      if (!operationConsideredDead(op)) {
        return;
      }

      // TODO(herhut): There should be a generic helper for this.
      recursiveErase(op, &dead_ops);
    });
    for (auto op : dead_ops) {
      op->erase();
    }
  }
};

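// Rewrites the signature of a generated GPU kernel so that its arguments are
// exactly the arguments and results of the surrounding function. A new kernel
// and a new gpu.launch_func are created and the originals are erased.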
struct RewriteKernelSignaturePass
    : RewriteKernelSignaturePassBase<RewriteKernelSignaturePass> {
  void runOnFunction() override {
    mlir::FuncOp func = getFunction();
    mlir::ModuleOp module = func->getParentOfType<mlir::ModuleOp>();
    getFunction().walk([&](mlir::gpu::LaunchFuncOp launchOp) {
      mlir::gpu::GPUFuncOp kernel =
          module.lookupSymbol<mlir::gpu::GPUFuncOp>(launchOp.kernel());

      if (kernel.getNumFuncArguments() !=
          func.getNumArguments() + func.getNumResults()) {
        kernel.emitError()
            << "number of kernel arguments does not match number "
            << "of arguments and results of surrounding function";
        signalPassFailure();
        return;
      }
      if (!llvm::hasSingleElement(func)) {
        func.emitError() << "surrounding function has more than one block";
        signalPassFailure();
        return;
      }

      // Compute a map from function arguments to kernel function operands.
      mlir::BlockAndValueMapping func_to_kernel;
      for (mlir::BlockArgument arg : func.getArguments()) {
        for (int i = 0, e = launchOp.getNumKernelOperands(); i < e; ++i) {
          if (launchOp.getKernelOperand(i) == arg) {
            func_to_kernel.map(arg, kernel.getArgument(i));
            break;
          }
        }
      }
      // Also add function results that are computed by the launch.
      mlir::Operation* returnOp = func.getBody().back().getTerminator();
      for (mlir::Value result : returnOp->getOperands()) {
        for (int i = 0, e = launchOp.getNumKernelOperands(); i < e; ++i) {
          if (launchOp.getKernelOperand(i) == result) {
            func_to_kernel.map(result, kernel.getArgument(i));
            break;
          }
        }
      }

      // Create a new kernel function with a modified signature. It will have
      // the parameters and result types of the original function as its
      // parameter types and otherwise will be void.
      auto gpu_module = kernel->getParentOfType<mlir::gpu::GPUModuleOp>();
      mlir::OpBuilder kernel_builder(gpu_module.body());
      auto operand_types = llvm::to_vector<4>(llvm::concat<const mlir::Type>(
          func.getType().getInputs(), func.getType().getResults()));
      auto new_kernel = kernel_builder.create<mlir::gpu::GPUFuncOp>(
          kernel.getLoc(), kernel.getName(),
          kernel_builder.getFunctionType(operand_types, {}));
      new_kernel->setAttr(mlir::gpu::GPUDialect::getKernelFuncAttrName(),
                          kernel_builder.getUnitAttr());

      // Create a map from old kernel argument to new one.
      mlir::BlockAndValueMapping old_kernel_to_new;
      for (int i = 0, e = func.getNumArguments(); i < e; ++i) {
        mlir::Value func_arg = func.getArgument(i);
        mlir::Value new_kernel_arg = new_kernel.getArgument(i);
        mlir::Value old_kernel_arg = func_to_kernel.lookupOrNull(func_arg);
        if (!old_kernel_arg) {
          kernel.emitOpError()
              << "argument " << i
              << " to containing function is not an argument to the kernel";
          signalPassFailure();
          return;
        }
        old_kernel_to_new.map(old_kernel_arg, new_kernel_arg);
      }
      for (int i = 0, e = returnOp->getNumOperands(); i < e; ++i) {
        mlir::Value ret_op = returnOp->getOperand(i);
        mlir::Value new_kernel_arg =
            new_kernel.getArgument(func.getNumArguments() + i);
        mlir::Value old_kernel_arg = func_to_kernel.lookupOrNull(ret_op);
        if (!old_kernel_arg) {
          kernel.emitOpError()
              << "result " << i
              << " of containing function is not an argument to the kernel";
          signalPassFailure();
          return;
        }
        old_kernel_to_new.map(old_kernel_arg, new_kernel_arg);
      }
      // Steal the body by appending the blocks and inserting a branch.
      kernel.body().cloneInto(&new_kernel.getBody(), old_kernel_to_new);
      kernel_builder.setInsertionPointToEnd(&new_kernel.body().front());
      kernel_builder.create<mlir::BranchOp>(
          new_kernel.getLoc(), &*std::next(new_kernel.body().begin()));
      // Now create a new launchOp calling the new kernel. We need to forward
      // the arguments of the surrounding function and the operands of the
      // return.
      mlir::SmallVector<mlir::Value, 4> new_operands;
      new_operands.reserve(new_kernel.getNumFuncArguments());
      new_operands.append(func.args_begin(), func.args_end());
      new_operands.append(returnOp->operand_begin(), returnOp->operand_end());
      mlir::OpBuilder launch_builder(launchOp);
      launch_builder.create<mlir::gpu::LaunchFuncOp>(
          launchOp.getLoc(), new_kernel, launchOp.getGridSizeOperandValues(),
          launchOp.getBlockSizeOperandValues(), new_operands);
      // Launch does not have results, so we can just erase it. The kernel
      // also needs to go.
      launchOp.erase();
      kernel.erase();
    });
  }
};

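// Greedily maps nested scf.parallel loops to GPU hardware dimensions using
// the upstream MLIR parallel-loop mapper.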
struct MapParallelLoopsPass : MapParallelLoopsPassBase<MapParallelLoopsPass> {
  void runOnFunction() override {
    mlir::greedilyMapParallelSCFToGPU(getFunction().getBody());
  }
};

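// Fuses parallel loops that are nested inside the region of an scf.parallel
// op, using the upstream naive parallel-loop fusion.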
struct FuseInnerParallelLoopsPass
    : FuseInnerParallelLoopsPassBase<FuseInnerParallelLoopsPass> {
  void runOnFunction() override {
    getFunction().walk([](mlir::scf::ParallelOp op) {
      mlir::scf::naivelyFuseParallelOps(op.region());
    });
  }
};

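// Collapses every multi-dimensional scf.parallel op into a one-dimensional
// loop by combining all of its induction variables into the first dimension.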
struct ParallelLoopCollapsingToFirstDimPass
    : ParallelLoopCollapsingToFirstDimPassBase<
          ParallelLoopCollapsingToFirstDimPass> {
  void runOnFunction() override {
    getFunction().walk([&](mlir::scf::ParallelOp op) {
      unsigned num_loops = op.getNumLoops();
      std::vector<unsigned> combinedLoops;
      combinedLoops.reserve(num_loops);
      for (unsigned i = 0; i < num_loops; ++i) {
        combinedLoops.push_back(i);
      }
      mlir::collapseParallelLoops(op, {combinedLoops});
    });
  }
};

}  // namespace

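// Factory functions for the passes defined above. A client would typically
// add these to a pass pipeline, e.g. (illustrative only):
//   pm.addNestedPass<mlir::FuncOp>(createFusionOpRemoverPass());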
std::unique_ptr<mlir::FunctionPass> createFusionOpRemoverPass() {
  return absl::make_unique<FusionOpRemoverPass>();
}

std::unique_ptr<mlir::FunctionPass> createStoreForwardingPass() {
  return absl::make_unique<StoreForwardingPass>();
}

std::unique_ptr<mlir::FunctionPass> createDeadTempBufferRemovalPass() {
  return absl::make_unique<DeadTempBufferRemovalPass>();
}

std::unique_ptr<mlir::FunctionPass> createRewriteKernelSignaturePass() {
  return absl::make_unique<RewriteKernelSignaturePass>();
}

std::unique_ptr<mlir::FunctionPass> createFuseInnerParallelLoopsPass() {
  return absl::make_unique<FuseInnerParallelLoopsPass>();
}

std::unique_ptr<mlir::FunctionPass> createMapParallelLoopsPass() {
  return absl::make_unique<MapParallelLoopsPass>();
}

std::unique_ptr<mlir::FunctionPass>
createParallelLoopCollapsingToFirstDimPass() {
  return absl::make_unique<ParallelLoopCollapsingToFirstDimPass>();
}

}  // namespace mlir_gpu
}  // namespace xla