• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_
17 #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_
18 
19 #include <memory>
20 
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Transforms/Bufferize.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 
26 namespace mlir {
27 namespace mhlo {
28 
29 struct RemoveSignTypeConverter;
30 
31 // Collection of rewrite patterns for lowering a general dot product.
32 void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns,
33                                           MLIRContext *ctx);
34 
35 // Collection of rewrite patterns for lowering complex operations to equivalent
36 // float operations.
37 void PopulateComplexLoweringPatterns(MLIRContext *context,
38                                      OwningRewritePatternList *patterns);
39 
40 void PopulateOptimizeMHLOPatterns(MLIRContext *context,
41                                   OwningRewritePatternList *patterns);
42 
43 // Rewrite patterns for einsum to equivalent dot_general legalization.
44 void PopulateEinsumToDotGeneralPatterns(mlir::MLIRContext *context,
45                                         OwningRewritePatternList *patterns);
46 
47 // Rewrite patterns for gather to equivalent torch index select legalization.
48 void PopulateGatherToTorchIndexSelectPatterns(
49     mlir::MLIRContext *context, OwningRewritePatternList *patterns);
50 
51 void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
52                                MLIRContext *ctx);
53 
54 // Collection of rewrite patterns for lowering all mhlo ops to their
55 // lmhlo counterparts.
56 void populateDynamicHLOToLHLOConversionPattern(
57     MLIRContext *context, BufferizeTypeConverter *converter,
58     OwningRewritePatternList *patterns);
59 
60 // Collection of rewrite patterns for lowering of HLO to LHLO dialect.
61 void populateHLOToLHLOConversionPattern(MLIRContext *context,
62                                         BufferizeTypeConverter *converter,
63                                         OwningRewritePatternList *patterns);
64 
65 // Collection of rewrite patterns for lowering of HLO to memref dialect.
66 // These patterns generally assume that the HLO operation are aliasing their
67 // input memrefs. If enforce_identity_map is set to true, copies will be
68 // inserted when the lowering would otherwise lead to a memref with a
69 // non-identity map.
70 void populateHLOToMemrefConversionPattern(
71     BufferizeTypeConverter *converter, RemoveSignTypeConverter *sign_converter,
72     OwningRewritePatternList *patterns, bool enforce_identity_map = true);
73 
74 // Collection of rewrite patterns for lowering of HLO to Linalg dialect.
75 void populateHLOToLinalgConversionPattern(MLIRContext *context,
76                                           TypeConverter &typeConverter,
77                                           OwningRewritePatternList *patterns);
78 
79 // Converter to signless intergers to be used with linalg conversion patterns.
80 std::unique_ptr<TypeConverter> createHloToLinalgSignedIntegerConverter();
81 
82 // Sets up legality definitions for materializing broadcasts.
83 void SetupMaterializeBroadcastsLegality(MLIRContext *context,
84                                         ConversionTarget *conversionTarget);
85 
86 // Populates a collection of rewrite patterns for materializing broadcast
87 // attributes to equivalent sequences of ops.
88 void PopulateMaterializeBroadcastsPatterns(MLIRContext *context,
89                                            OwningRewritePatternList *patterns);
90 
91 // Populates a collection of rewrite patterns to realize element-wise operations
92 // on ranked tensors where possible.
93 void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
94                                           OwningRewritePatternList *patterns);
95 
96 void PopulateDynamicShapeFusionPatterns(MLIRContext *context,
97                                         OwningRewritePatternList *patterns);
98 
99 // Populate a collection of conversion patterns for un-fusing
100 // batch_norm_inference and batch_norm_training into constituent HLO ops.
101 // TODO(laurenzo): Implement un-fusing of batch_norm_training.
102 void PopulateUnfuseBatchNormPatterns(MLIRContext *context,
103                                      OwningRewritePatternList *patterns);
104 
105 // Populates patterns that translate the trigonometric operations from the
106 // standard dialect to approximations that do not use intrinsics.
107 void PopulateTrigonometricToApproximationPatterns(
108     MLIRContext *context, OwningRewritePatternList *patterns);
109 
110 // Populate patterns to move dynamic broadcasts up over element-wise operations
111 // and broadcast the operands rather than the result. This will eventually allow
112 // for larger fusions.
113 void PopulateBroadcastsPropagationPatterns(MLIRContext *context,
114                                            OwningRewritePatternList *patterns);
115 
116 /// Populate rank specialization clustering and lowering patterns.
117 void PopulateRankSpecializationClusterPatterns(
118     MLIRContext *context, OwningRewritePatternList *patterns);
119 void PopulateRankSpecializationToSCFPatterns(MLIRContext *context,
120                                              OwningRewritePatternList *patterns,
121                                              int64_t max_target_rank);
122 
123 }  // namespace mhlo
124 
125 namespace chlo {
126 
127 // Populates a collection of conversion patterns for legalizing broadcasting
128 // client-HLO to their non-broadcasting counterparts.
129 void PopulateChloBroadcastingPatterns(MLIRContext *context,
130                                       OwningRewritePatternList *patterns);
131 
132 // Populates a collection of conversion patterns for legalizing client-HLO to
133 // HLO by decomposing client-operations to corresponding sequences of more
134 // primitive operations. This does not include the
135 // PopulateChloBroadcastingPatterns above.
136 void PopulateDecomposeChloPatterns(MLIRContext *context,
137                                    OwningRewritePatternList *patterns);
138 
139 }  // namespace chlo
140 
141 class LLVMTypeConverter;
142 class SymbolTable;
143 
144 namespace disc_ral {
145 
146 void populateDiscRalToLLVMConversionPatterns(LLVMTypeConverter *converter,
147                                              SymbolTable *symbol_table,
148                                              RewritePatternSet *patterns);
149 
150 }  // namespace disc_ral
151 
152 }  // namespace mlir
153 
154 #endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_
155