1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for translating mixed IR to buffer form.
17 // Currently it supports MHLO and some operations from the Standard dialect.
18
19 #include <memory>
20
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
24 #include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project
25 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" // from @llvm-project
26 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project
27 #include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project
28 #include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
29 #include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project
30 #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
31 #include "mlir/Dialect/Shape/Transforms/Passes.h" // from @llvm-project
32 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
33 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" // from @llvm-project
34 #include "mlir/Dialect/StandardOps/Transforms/Passes.h" // from @llvm-project
35 #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
36 #include "mlir/Dialect/Tensor/Transforms/Passes.h" // from @llvm-project
37 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
38 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
39 #include "mlir/IR/MLIRContext.h" // from @llvm-project
40 #include "mlir/IR/Operation.h" // from @llvm-project
41 #include "mlir/IR/PatternMatch.h" // from @llvm-project
42 #include "mlir/IR/Visitors.h" // from @llvm-project
43 #include "mlir/Transforms/Bufferize.h" // from @llvm-project
44 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
45 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
46 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
47 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
48 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
49 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
50 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
51 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
52
53 namespace mlir {
54 namespace kernel_gen {
55 namespace transforms {
56 namespace {
57
58 #define GEN_PASS_CLASSES
59 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
60
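/// Bufferization type-conversion helpers: materializeTensorLoad (below) turns
/// a memref value back into a tensor, and CustomBufferizeTypeConverter uses it
/// while automatically populating the relevant materializations and type
/// conversions for bufferization.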
63
materializeTensorLoad(OpBuilder & builder,TensorType type,ValueRange inputs,Location loc)64 static Value materializeTensorLoad(OpBuilder& builder, TensorType type,
65 ValueRange inputs, Location loc) {
66 assert(inputs.size() == 1);
67 assert(inputs[0].getType().isa<BaseMemRefType>());
68 return builder.create<TensorLoadOp>(loc, type, inputs[0]);
69 }
70
71 // TODO(pifon): Remove as soon as https://reviews.llvm.org/D93126 is landed.
72 class CustomBufferizeTypeConverter : public BufferizeTypeConverter {
73 public:
CustomBufferizeTypeConverter()74 CustomBufferizeTypeConverter() {
75 // Keep all types unchanged.
76 addConversion([](Type type) { return type; });
77 // Convert RankedTensorType to MemRefType.
78 addConversion([](RankedTensorType type) -> Type {
79 return MemRefType::get(type.getShape(), type.getElementType());
80 });
81 // Convert UnrankedTensorType to UnrankedMemRefType.
82 addConversion([](UnrankedTensorType type) -> Type {
83 return UnrankedMemRefType::get(type.getElementType(), 0);
84 });
85 addArgumentMaterialization(materializeTensorLoad);
86 addSourceMaterialization(materializeTensorLoad);
87 addTargetMaterialization([](OpBuilder& builder, BaseMemRefType type,
88 ValueRange inputs, Location loc) -> Value {
89 assert(inputs.size() == 1);
90 // Target materialization is invoked if the new operand type does not
91 // match the expected type. A special case is when the new operand type is
92 // a memref with a specified layout, i.e. non-empty affine map.
93 // TODO(pifon) : Change how target materialization is invoked in dialect
94 // conversion.
95 if (auto memref_type = inputs[0].getType().dyn_cast<MemRefType>()) {
96 assert(!memref_type.getAffineMaps().empty());
97 return inputs[0];
98 }
99 assert(inputs[0].getType().isa<TensorType>());
100 return builder.create<TensorToMemrefOp>(loc, type, inputs[0]);
101 });
102 }
103 };
104
105 struct HloBufferizePass : public HloBufferizePassBase<HloBufferizePass> {
106 // TODO(b/173201243): Move to tablegen.
getDependentDialectsmlir::kernel_gen::transforms::__anon65e8b27c0111::HloBufferizePass107 void getDependentDialects(DialectRegistry& registry) const override {
108 registry.insert<lmhlo::LmhloDialect>();
109 }
110
111 public:
runOnOperationmlir::kernel_gen::transforms::__anon65e8b27c0111::HloBufferizePass112 void runOnOperation() override {
113 OwningRewritePatternList patterns;
114 auto& context = getContext();
115 ConversionTarget target(context);
116 target.addLegalDialect<complex::ComplexDialect, lmhlo::LmhloDialect,
117 StandardOpsDialect, tensor::TensorDialect,
118 math::MathDialect>();
119 target.addIllegalDialect<mhlo::MhloDialect>();
120
121 CustomBufferizeTypeConverter converter;
122 // Configure bufferize pattern for functions and lhlo.
123 mhlo::populateDynamicHLOToLHLOConversionPattern(
124 &context, &converter, &patterns, /*insert_copy=*/false);
125 populateFuncOpTypeConversionPattern(patterns, &context, converter);
126 populateCallOpTypeConversionPattern(patterns, &context, converter);
127 populateBranchOpInterfaceAndReturnOpTypeConversionPattern(
128 patterns, &context, converter);
129
130 // Configure legality and structural patterns.
131 populateBufferizeMaterializationLegality(target);
132 linalg::populateLinalgBufferizePatterns(&context, converter, patterns);
133 populateShapeStructuralTypeConversionsAndLegality(&context, converter,
134 patterns, target);
135 scf::populateSCFStructuralTypeConversionsAndLegality(&context, converter,
136 patterns, target);
137 // TODO(herhut): Move this legality configuration to bufferize itself?
138 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
139 auto inputs = op.getType().getInputs();
140 auto results = op.getType().getResults();
141 return converter.isLegal(inputs) && converter.isLegal(results) &&
142 converter.isLegal(&op.getBody());
143 });
144 auto isLegalOp = [&](Operation* op) { return converter.isLegal(op); };
145 target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOp);
146 target.addDynamicallyLegalOp<CallOp, ReturnOp>(isLegalOp);
147
148 if (failed(applyPartialConversion(getOperation(), target,
149 std::move(patterns))))
150 signalPassFailure();
151 }
152 };
153
154 struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> {
155 // TODO(b/173201243): Move to tablegen.
getDependentDialectsmlir::kernel_gen::transforms::__anon65e8b27c0111::FinalBufferizePass156 void getDependentDialects(DialectRegistry& registry) const override {
157 registry.insert<AffineDialect, scf::SCFDialect, shape::ShapeDialect,
158 tensor::TensorDialect, tf_framework::TFFrameworkDialect,
159 lmhlo::LmhloDialect>();
160 }
161
162 public:
runOnOperationmlir::kernel_gen::transforms::__anon65e8b27c0111::FinalBufferizePass163 void runOnOperation() override {
164 auto& context = getContext();
165 ConversionTarget target(context);
166 target.addLegalDialect<complex::ComplexDialect, scf::SCFDialect,
167 StandardOpsDialect, tensor::TensorDialect,
168 tf_framework::TFFrameworkDialect, AffineDialect,
169 shape::ShapeDialect, lmhlo::LmhloDialect,
170 linalg::LinalgDialect, math::MathDialect>();
171 target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();
172
173 target.addIllegalDialect<mhlo::MhloDialect>();
174 target.addIllegalOp<tensor::GenerateOp, tensor::ExtractOp,
175 tensor::FromElementsOp, tensor::CastOp, TensorLoadOp,
176 TensorToMemrefOp>();
177 BufferizeTypeConverter converter;
178 auto typesAreLegal = [&converter](Operation* op) {
179 return converter.isLegal(op->getOperandTypes()) &&
180 converter.isLegal(op->getResultTypes());
181 };
182 target.addDynamicallyLegalOp<ConstantOp, DimOp, RankOp, SelectOp>(
183 typesAreLegal);
184
185 OwningRewritePatternList patterns;
186 populateTensorBufferizePatterns(&context, converter, patterns);
187 populateStdBufferizePatterns(&context, converter, patterns);
188 populateEliminateBufferizeMaterializationsPatterns(&context, converter,
189 patterns);
190 populateExtraStdBufferizePattern(&context, &converter, &patterns);
191 populateShapeStructuralTypeConversionsAndLegality(&context, converter,
192 patterns, target);
193 scf::populateSCFStructuralTypeConversionsAndLegality(&context, converter,
194 patterns, target);
195
196 auto module = getOperation();
197 if (failed(applyFullConversion(module, target, std::move(patterns)))) {
198 signalPassFailure();
199 }
200 }
201 };
202
203 } // namespace
204
CreateHloBufferizePass()205 std::unique_ptr<OperationPass<ModuleOp> > CreateHloBufferizePass() {
206 return std::make_unique<HloBufferizePass>();
207 }
208
CreateFinalBufferizePass()209 std::unique_ptr<OperationPass<ModuleOp> > CreateFinalBufferizePass() {
210 return std::make_unique<FinalBufferizePass>();
211 }
212
213 } // namespace transforms
214 } // namespace kernel_gen
215 } // namespace mlir
216