• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- TestLinalgHoisting.cpp - Test Linalg hoisting functions ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for testing Linalg hoisting functions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
15 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
16 #include "mlir/Pass/Pass.h"
17 
18 using namespace mlir;
19 using namespace mlir::linalg;
20 
21 namespace {
22 struct TestLinalgHoisting
23     : public PassWrapper<TestLinalgHoisting, FunctionPass> {
24   TestLinalgHoisting() = default;
TestLinalgHoisting__anon23851e290111::TestLinalgHoisting25   TestLinalgHoisting(const TestLinalgHoisting &pass) {}
getDependentDialects__anon23851e290111::TestLinalgHoisting26   void getDependentDialects(DialectRegistry &registry) const override {
27     registry.insert<AffineDialect>();
28   }
29 
30   void runOnFunction() override;
31 
32   Option<bool> testHoistViewAllocs{
33       *this, "test-hoist-view-allocs",
34       llvm::cl::desc("Test hoisting alloc used by view"),
35       llvm::cl::init(false)};
36   Option<bool> testHoistRedundantTransfers{
37       *this, "test-hoist-redundant-transfers",
38       llvm::cl::desc("Test hoisting transfer_read/transfer_write pairs"),
39       llvm::cl::init(false)};
40 };
41 } // end anonymous namespace
42 
runOnFunction()43 void TestLinalgHoisting::runOnFunction() {
44   if (testHoistViewAllocs) {
45     hoistViewAllocOps(getFunction());
46     return;
47   }
48   if (testHoistRedundantTransfers) {
49     hoistRedundantVectorTransfers(getFunction());
50     return;
51   }
52 }
53 
54 namespace mlir {
55 namespace test {
registerTestLinalgHoisting()56 void registerTestLinalgHoisting() {
57   PassRegistration<TestLinalgHoisting> testTestLinalgHoistingPass(
58       "test-linalg-hoisting", "Test Linalg hoisting functions.");
59 }
60 } // namespace test
61 } // namespace mlir
62