1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_ 17 #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_ 18 19 #include <memory> 20 21 #include "llvm/ADT/ArrayRef.h" 22 23 namespace mlir { 24 25 class FuncOp; 26 class FunctionPass; 27 class ModuleOp; 28 class Operation; 29 template <typename T> 30 class OperationPass; 31 class Pass; 32 namespace lmhlo { 33 class FusionOp; 34 } 35 36 namespace mhlo { 37 38 /// Lowers HLO control flow ops to the Standard dialect. 39 std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass(); 40 41 /// Lowers MHLO control flow ops to the SCF dialect. 42 std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass(); 43 44 /// Lowers from HLO dialect to Standard dialect. 45 std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass(); 46 47 /// Lowers from the CHLO dialect to the HLO dialect. 48 std::unique_ptr<FunctionPass> createChloLegalizeToHloPass( 49 bool legalize_broadcasts = true, bool expand_compositions = true); 50 51 // canonicalize reduction ops to be suitable for codegen. 52 std::unique_ptr<FunctionPass> createHloCanonicalizeReductionPass(); 53 54 /// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary 55 /// buffers if necessary. 56 std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(); 57 58 /// Lowers from HLO dialect to Memref dialect allocating/deallocating temporary 59 /// buffers if necessary. 60 std::unique_ptr<FunctionPass> createLegalizeToMemrefPass(); 61 62 // Lowers from HLO dialect to Linalg dialect. 63 std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass(); 64 65 // Place shape calculating subgraph on cpu. 66 std::unique_ptr<OperationPass<ModuleOp>> createMarkShapeCalcOpPass(); 67 68 // Sinks constants implicitly captured in control flow regions. This is 69 // necessary to export to XLA. 70 std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass(); 71 72 // fuse mhlo ops to kLoop/kInput fusion patterns 73 std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass(); 74 75 /// Lowers trigonometric operations from the standard dialect to approximations 76 /// that do not use intrinsics. 77 std::unique_ptr<OperationPass<FuncOp>> 78 createLegalizeTrigonometricToApproximationPass(); 79 80 // Move dynamic broadcasts up over element-wise operations and broadcast the 81 // operands rather than the result. This will eventually allow for larger 82 // fusions. 83 std::unique_ptr<FunctionPass> createBroadcastPropagationPass(); 84 85 /// Rank specialization passes: 86 /// - Find compatible operations and group them together in one rank 87 /// specialization cluster. 88 /// - Lower rank specialization clusters to SCF and ranked operations. 89 std::unique_ptr<FunctionPass> createRankSpecializationClusterPass(); 90 std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass( 91 int64_t max_target_rank = 5); 92 93 std::unique_ptr<FunctionPass> createOptimizeMhloPass(); 94 std::unique_ptr<FunctionPass> createLowerComplexPass(); 95 std::unique_ptr<::mlir::Pass> createLegalizeGeneralDotPass(); 96 std::unique_ptr<FunctionPass> createLegalizeEinsumToDotGeneralPass(); 97 std::unique_ptr<FunctionPass> createLegalizeGatherToTorchIndexSelectPass(); 98 std::unique_ptr<FunctionPass> createFlattenTuplePass(); 99 100 // Creates a pass for expanding mhlo.tuple ops. 101 std::unique_ptr<OperationPass<ModuleOp>> CreateExpandHloTuplesPass( 102 const std::string& entry_function_name = "main"); 103 104 } // namespace mhlo 105 106 namespace lmhlo { 107 108 // Lowers from LHLO dialect to Affine dialect. 109 std::unique_ptr<OperationPass<FuncOp>> createLhloLegalizeToAffinePass(); 110 111 // Lowers from LHLO dialect to Linalg dialect. 112 std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass(); 113 114 // Lowers from LHLO dialect to GPU dialect. 115 std::unique_ptr<FunctionPass> createLegalizeToGpuPass(); 116 117 // Fuses linalg ops obtained after LHLO lowering. To enable fusion, 118 // operations are first tiled. 119 // 120 // When 'use_parallel_loops' is set, the tiling will use scf.parallel 121 // operations. Otherwise, scf.for operations are used. 122 // 123 // 'tile_sizes' provides the tile sizes to use for tiling. If the linalg 124 // operation has more dimensions than tile sizes provided, 1 is used as 125 // default. 126 std::unique_ptr<FunctionPass> createLhloFuseLinalgPass( 127 bool use_parallel_loops = false, llvm::ArrayRef<unsigned> tile_sizes = {}); 128 129 // Lowers from LHLO dialect to parallel loops. 130 std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass(); 131 132 // Legalizes tensor load ops that are inserted during mhlo to lmhlo conversion. 133 std::unique_ptr<OperationPass<FuncOp>> createLegalizeTensorLoadOpPass(); 134 135 // fuse lmhlo ops to kLoop/kInput fusion patterns 136 std::unique_ptr<OperationPass<FuncOp>> createLhloFusionPass( 137 int max_num_arguments_per_kernel = 64); 138 139 // inline lmhlo.Fusion 140 std::unique_ptr<OperationPass<FuncOp>> createLhloFusionInlinerPass(); 141 142 // Lowers the roots of lmhlo.fusion to parallel loops 143 std::unique_ptr<OperationPass<FuncOp>> 144 createLhloLegalizeRootsToParallelLoopsPass(); 145 146 // Input inline fusion pass for fusion codegen 147 std::unique_ptr<OperationPass<lmhlo::FusionOp>> createInputInlineFusionPass(); 148 149 } // namespace lmhlo 150 151 namespace disc_ral { 152 153 std::unique_ptr<OperationPass<ModuleOp>> createRalInjectExecutionContextPass( 154 const std::string& entry_func_name = "main"); 155 156 // Lower some specific ops to library calls (modeled by `disc_ral.launch` op). 157 std::unique_ptr<mlir::FunctionPass> createRalLowerToLibraryCallPass(); 158 159 // Lower disc to llvm dialect 160 std::unique_ptr<OperationPass<ModuleOp>> createRalToLLVMPass(); 161 162 } // namespace disc_ral 163 164 } // namespace mlir 165 166 #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_ 167