• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for translating mixed IR to buffer form.
17 // Currently it supports MHLO and some operations from the Standard dialect.
18 
19 #include <memory>
20 
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
24 #include "mlir/Dialect/Complex/IR/Complex.h"  // from @llvm-project
25 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"  // from @llvm-project
26 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"  // from @llvm-project
27 #include "mlir/Dialect/Math/IR/Math.h"  // from @llvm-project
28 #include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
29 #include "mlir/Dialect/SCF/Transforms.h"  // from @llvm-project
30 #include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
31 #include "mlir/Dialect/Shape/Transforms/Passes.h"  // from @llvm-project
32 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
33 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"  // from @llvm-project
34 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"  // from @llvm-project
35 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
36 #include "mlir/Dialect/Tensor/Transforms/Passes.h"  // from @llvm-project
37 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
38 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
39 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
40 #include "mlir/IR/Operation.h"  // from @llvm-project
41 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
42 #include "mlir/IR/Visitors.h"  // from @llvm-project
43 #include "mlir/Transforms/Bufferize.h"  // from @llvm-project
44 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
45 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
46 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
47 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
48 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
49 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
50 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
51 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
52 
53 namespace mlir {
54 namespace kernel_gen {
55 namespace transforms {
56 namespace {
57 
58 #define GEN_PASS_CLASSES
59 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
60 
/// Bufferization helpers: the tensor-load materialization function below and a
/// type converter class that automatically populates the relevant
/// materializations and type conversions for bufferization.
63 
materializeTensorLoad(OpBuilder & builder,TensorType type,ValueRange inputs,Location loc)64 static Value materializeTensorLoad(OpBuilder& builder, TensorType type,
65                                    ValueRange inputs, Location loc) {
66   assert(inputs.size() == 1);
67   assert(inputs[0].getType().isa<BaseMemRefType>());
68   return builder.create<TensorLoadOp>(loc, type, inputs[0]);
69 }
70 
71 // TODO(pifon): Remove as soon as https://reviews.llvm.org/D93126 is landed.
72 class CustomBufferizeTypeConverter : public BufferizeTypeConverter {
73  public:
CustomBufferizeTypeConverter()74   CustomBufferizeTypeConverter() {
75     // Keep all types unchanged.
76     addConversion([](Type type) { return type; });
77     // Convert RankedTensorType to MemRefType.
78     addConversion([](RankedTensorType type) -> Type {
79       return MemRefType::get(type.getShape(), type.getElementType());
80     });
81     // Convert UnrankedTensorType to UnrankedMemRefType.
82     addConversion([](UnrankedTensorType type) -> Type {
83       return UnrankedMemRefType::get(type.getElementType(), 0);
84     });
85     addArgumentMaterialization(materializeTensorLoad);
86     addSourceMaterialization(materializeTensorLoad);
87     addTargetMaterialization([](OpBuilder& builder, BaseMemRefType type,
88                                 ValueRange inputs, Location loc) -> Value {
89       assert(inputs.size() == 1);
90       // Target materialization is invoked if the new operand type does not
91       // match the expected type. A special case is when the new operand type is
92       // a memref with a specified layout, i.e. non-empty affine map.
93       // TODO(pifon) : Change how target materialization is invoked in dialect
94       // conversion.
95       if (auto memref_type = inputs[0].getType().dyn_cast<MemRefType>()) {
96         assert(!memref_type.getAffineMaps().empty());
97         return inputs[0];
98       }
99       assert(inputs[0].getType().isa<TensorType>());
100       return builder.create<TensorToMemrefOp>(loc, type, inputs[0]);
101     });
102   }
103 };
104 
105 struct HloBufferizePass : public HloBufferizePassBase<HloBufferizePass> {
106   // TODO(b/173201243): Move to tablegen.
getDependentDialectsmlir::kernel_gen::transforms::__anon65e8b27c0111::HloBufferizePass107   void getDependentDialects(DialectRegistry& registry) const override {
108     registry.insert<lmhlo::LmhloDialect>();
109   }
110 
111  public:
runOnOperationmlir::kernel_gen::transforms::__anon65e8b27c0111::HloBufferizePass112   void runOnOperation() override {
113     OwningRewritePatternList patterns;
114     auto& context = getContext();
115     ConversionTarget target(context);
116     target.addLegalDialect<complex::ComplexDialect, lmhlo::LmhloDialect,
117                            StandardOpsDialect, tensor::TensorDialect,
118                            math::MathDialect>();
119     target.addIllegalDialect<mhlo::MhloDialect>();
120 
121     CustomBufferizeTypeConverter converter;
122     // Configure bufferize pattern for functions and lhlo.
123     mhlo::populateDynamicHLOToLHLOConversionPattern(
124         &context, &converter, &patterns, /*insert_copy=*/false);
125     populateFuncOpTypeConversionPattern(patterns, &context, converter);
126     populateCallOpTypeConversionPattern(patterns, &context, converter);
127     populateBranchOpInterfaceAndReturnOpTypeConversionPattern(
128         patterns, &context, converter);
129 
130     // Configure legality and structural patterns.
131     populateBufferizeMaterializationLegality(target);
132     linalg::populateLinalgBufferizePatterns(&context, converter, patterns);
133     populateShapeStructuralTypeConversionsAndLegality(&context, converter,
134                                                       patterns, target);
135     scf::populateSCFStructuralTypeConversionsAndLegality(&context, converter,
136                                                          patterns, target);
137     // TODO(herhut): Move this legality configuration to bufferize itself?
138     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
139       auto inputs = op.getType().getInputs();
140       auto results = op.getType().getResults();
141       return converter.isLegal(inputs) && converter.isLegal(results) &&
142              converter.isLegal(&op.getBody());
143     });
144     auto isLegalOp = [&](Operation* op) { return converter.isLegal(op); };
145     target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOp);
146     target.addDynamicallyLegalOp<CallOp, ReturnOp>(isLegalOp);
147 
148     if (failed(applyPartialConversion(getOperation(), target,
149                                       std::move(patterns))))
150       signalPassFailure();
151   }
152 };
153 
154 struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> {
155   // TODO(b/173201243): Move to tablegen.
getDependentDialectsmlir::kernel_gen::transforms::__anon65e8b27c0111::FinalBufferizePass156   void getDependentDialects(DialectRegistry& registry) const override {
157     registry.insert<AffineDialect, scf::SCFDialect, shape::ShapeDialect,
158                     tensor::TensorDialect, tf_framework::TFFrameworkDialect,
159                     lmhlo::LmhloDialect>();
160   }
161 
162  public:
runOnOperationmlir::kernel_gen::transforms::__anon65e8b27c0111::FinalBufferizePass163   void runOnOperation() override {
164     auto& context = getContext();
165     ConversionTarget target(context);
166     target.addLegalDialect<complex::ComplexDialect, scf::SCFDialect,
167                            StandardOpsDialect, tensor::TensorDialect,
168                            tf_framework::TFFrameworkDialect, AffineDialect,
169                            shape::ShapeDialect, lmhlo::LmhloDialect,
170                            linalg::LinalgDialect, math::MathDialect>();
171     target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();
172 
173     target.addIllegalDialect<mhlo::MhloDialect>();
174     target.addIllegalOp<tensor::GenerateOp, tensor::ExtractOp,
175                         tensor::FromElementsOp, tensor::CastOp, TensorLoadOp,
176                         TensorToMemrefOp>();
177     BufferizeTypeConverter converter;
178     auto typesAreLegal = [&converter](Operation* op) {
179       return converter.isLegal(op->getOperandTypes()) &&
180              converter.isLegal(op->getResultTypes());
181     };
182     target.addDynamicallyLegalOp<ConstantOp, DimOp, RankOp, SelectOp>(
183         typesAreLegal);
184 
185     OwningRewritePatternList patterns;
186     populateTensorBufferizePatterns(&context, converter, patterns);
187     populateStdBufferizePatterns(&context, converter, patterns);
188     populateEliminateBufferizeMaterializationsPatterns(&context, converter,
189                                                        patterns);
190     populateExtraStdBufferizePattern(&context, &converter, &patterns);
191     populateShapeStructuralTypeConversionsAndLegality(&context, converter,
192                                                       patterns, target);
193     scf::populateSCFStructuralTypeConversionsAndLegality(&context, converter,
194                                                          patterns, target);
195 
196     auto module = getOperation();
197     if (failed(applyFullConversion(module, target, std::move(patterns)))) {
198       signalPassFailure();
199     }
200   }
201 };
202 
203 }  // namespace
204 
CreateHloBufferizePass()205 std::unique_ptr<OperationPass<ModuleOp> > CreateHloBufferizePass() {
206   return std::make_unique<HloBufferizePass>();
207 }
208 
CreateFinalBufferizePass()209 std::unique_ptr<OperationPass<ModuleOp> > CreateFinalBufferizePass() {
210   return std::make_unique<FinalBufferizePass>();
211 }
212 
213 }  // namespace transforms
214 }  // namespace kernel_gen
215 }  // namespace mlir
216