1 //===- TestLinalgHoisting.cpp - Test Linalg hoisting functions ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg hoisting functions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 15 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 16 #include "mlir/Pass/Pass.h" 17 18 using namespace mlir; 19 using namespace mlir::linalg; 20 21 namespace { 22 struct TestLinalgHoisting 23 : public PassWrapper<TestLinalgHoisting, FunctionPass> { 24 TestLinalgHoisting() = default; TestLinalgHoisting__anon23851e290111::TestLinalgHoisting25 TestLinalgHoisting(const TestLinalgHoisting &pass) {} getDependentDialects__anon23851e290111::TestLinalgHoisting26 void getDependentDialects(DialectRegistry ®istry) const override { 27 registry.insert<AffineDialect>(); 28 } 29 30 void runOnFunction() override; 31 32 Option<bool> testHoistViewAllocs{ 33 *this, "test-hoist-view-allocs", 34 llvm::cl::desc("Test hoisting alloc used by view"), 35 llvm::cl::init(false)}; 36 Option<bool> testHoistRedundantTransfers{ 37 *this, "test-hoist-redundant-transfers", 38 llvm::cl::desc("Test hoisting transfer_read/transfer_write pairs"), 39 llvm::cl::init(false)}; 40 }; 41 } // end anonymous namespace 42 runOnFunction()43void TestLinalgHoisting::runOnFunction() { 44 if (testHoistViewAllocs) { 45 hoistViewAllocOps(getFunction()); 46 return; 47 } 48 if (testHoistRedundantTransfers) { 49 hoistRedundantVectorTransfers(getFunction()); 50 return; 51 } 52 } 53 54 namespace mlir { 55 namespace test { registerTestLinalgHoisting()56void registerTestLinalgHoisting() { 57 PassRegistration<TestLinalgHoisting> testTestLinalgHoistingPass( 58 "test-linalg-hoisting", "Test Linalg hoisting functions."); 59 } 60 } // namespace test 61 } // namespace mlir 62