1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for translating mixed IR to buffer form.
17 // Currently it supports MHLO and some operations from the Standard dialect.
18
19 #include <memory>
20 #include <utility>
21
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/raw_ostream.h"
24 #include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
25 #include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project
26 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" // from @llvm-project
27 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project
28 #include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project
29 #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
30 #include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
31 #include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project
32 #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
33 #include "mlir/Dialect/Shape/Transforms/Passes.h" // from @llvm-project
34 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
35 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" // from @llvm-project
36 #include "mlir/Dialect/StandardOps/Transforms/Passes.h" // from @llvm-project
37 #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
38 #include "mlir/Dialect/Tensor/Transforms/Passes.h" // from @llvm-project
39 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
40 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
41 #include "mlir/IR/MLIRContext.h" // from @llvm-project
42 #include "mlir/IR/Operation.h" // from @llvm-project
43 #include "mlir/IR/PatternMatch.h" // from @llvm-project
44 #include "mlir/IR/Visitors.h" // from @llvm-project
45 #include "mlir/Transforms/Bufferize.h" // from @llvm-project
46 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
47 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
48 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
49 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
50 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
51 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
52 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/type_conversion.h"
53 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
54 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
55 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
56
57 namespace mlir {
58 namespace kernel_gen {
59 namespace transforms {
60 namespace {
61
62 #define GEN_PASS_CLASSES
63 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
64
65 /// A helper type converter class that automatically populates the relevant
66 /// materializations and type conversions for bufferization.
67
materializeTensorLoad(OpBuilder & builder,TensorType type,ValueRange inputs,Location loc)68 static Value materializeTensorLoad(OpBuilder& builder, TensorType type,
69 ValueRange inputs, Location loc) {
70 assert(inputs.size() == 1);
71 assert(inputs[0].getType().isa<BaseMemRefType>());
72 return builder.create<memref::TensorLoadOp>(loc, type, inputs[0]);
73 }
74
75 // TODO(pifon): Remove as soon as https://reviews.llvm.org/D93126 is landed.
76 class CustomBufferizeTypeConverter : public BufferizeTypeConverter {
77 public:
CustomBufferizeTypeConverter()78 CustomBufferizeTypeConverter() {
79 // Keep all types unchanged.
80 addConversion([](Type type) { return type; });
81 // Convert RankedTensorType to MemRefType.
82 addConversion([](RankedTensorType type) -> Type {
83 return MemRefType::get(type.getShape(), type.getElementType());
84 });
85 // Convert UnrankedTensorType to UnrankedMemRefType.
86 addConversion([](UnrankedTensorType type) -> Type {
87 return UnrankedMemRefType::get(type.getElementType(), 0);
88 });
89 addArgumentMaterialization(materializeTensorLoad);
90 addSourceMaterialization(materializeTensorLoad);
91 addTargetMaterialization([](OpBuilder& builder, BaseMemRefType type,
92 ValueRange inputs, Location loc) -> Value {
93 assert(inputs.size() == 1);
94 // Target materialization is invoked if the new operand type does not
95 // match the expected type. A special case is when the new operand type is
96 // a memref with a specified layout, i.e. non-empty affine map.
97 // TODO(pifon) : Change how target materialization is invoked in dialect
98 // conversion.
99 if (auto memref_type = inputs[0].getType().dyn_cast<MemRefType>()) {
100 assert(!memref_type.getAffineMaps().empty());
101 return inputs[0];
102 }
103 assert(inputs[0].getType().isa<TensorType>());
104 return builder.create<memref::BufferCastOp>(loc, type, inputs[0]);
105 });
106 }
107 };
108
109 struct ComputeOpAndFuncBufferizePass
110 : public ComputeOpAndFuncBufferizePassBase<ComputeOpAndFuncBufferizePass> {
111 // TODO(b/173201243): Move to tablegen.
getDependentDialectsmlir::kernel_gen::transforms::__anon5fbbc5010111::ComputeOpAndFuncBufferizePass112 void getDependentDialects(DialectRegistry& registry) const override {
113 registry.insert<lmhlo::LmhloDialect, memref::MemRefDialect>();
114 }
115
116 public:
runOnOperationmlir::kernel_gen::transforms::__anon5fbbc5010111::ComputeOpAndFuncBufferizePass117 void runOnOperation() override {
118 RewritePatternSet patterns(&getContext());
119 auto& context = getContext();
120 ConversionTarget target(context);
121 target.addLegalDialect<complex::ComplexDialect, lmhlo::LmhloDialect,
122 memref::MemRefDialect, StandardOpsDialect,
123 tensor::TensorDialect, math::MathDialect>();
124 target.addLegalOp<UnrealizedConversionCastOp>();
125 target.addIllegalDialect<mhlo::MhloDialect>();
126 target.addIllegalOp<tensor::ExtractSliceOp, tensor::InsertSliceOp>();
127
128 CustomBufferizeTypeConverter converter;
129 mhlo::RemoveSignTypeConverter remove_sign_converter;
130
131 // Configure bufferize pattern for functions and lhlo.
132 mhlo::populateHLOToMemrefConversionPattern(
133 &converter, &remove_sign_converter, &patterns,
134 /*enforce_identity_map=*/false);
135 populateFuncOpTypeConversionPattern(patterns, converter);
136 populateCallOpTypeConversionPattern(patterns, converter);
137 populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
138 populateReturnOpTypeConversionPattern(patterns, converter);
139
140 // Configure legality and structural patterns.
141 populateBufferizeMaterializationLegality(target);
142 linalg::populateLinalgBufferizePatterns(converter, patterns);
143 populateShapeStructuralTypeConversionsAndLegality(converter, patterns,
144 target);
145 scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
146 target);
147 // TODO(herhut): Move this legality configuration to bufferize itself?
148 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
149 auto inputs = op.getType().getInputs();
150 auto results = op.getType().getResults();
151 return converter.isLegal(inputs) && converter.isLegal(results) &&
152 converter.isLegal(&op.getBody());
153 });
154 auto isLegalOp = [&](Operation* op) { return converter.isLegal(op); };
155 target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOp);
156 target.addDynamicallyLegalOp<CallOp, ReturnOp>(isLegalOp);
157
158 if (failed(applyPartialConversion(getOperation(), target,
159 std::move(patterns))))
160 signalPassFailure();
161 }
162 };
163
164 struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> {
165 // TODO(b/173201243): Move to tablegen.
getDependentDialectsmlir::kernel_gen::transforms::__anon5fbbc5010111::FinalBufferizePass166 void getDependentDialects(DialectRegistry& registry) const override {
167 registry.insert<AffineDialect, memref::MemRefDialect, scf::SCFDialect,
168 shape::ShapeDialect, tensor::TensorDialect,
169 tf_framework::TFFrameworkDialect, lmhlo::LmhloDialect>();
170 }
171
172 public:
runOnOperationmlir::kernel_gen::transforms::__anon5fbbc5010111::FinalBufferizePass173 void runOnOperation() override {
174 auto& context = getContext();
175 ConversionTarget target(context);
176 target.addLegalDialect<
177 complex::ComplexDialect, memref::MemRefDialect, StandardOpsDialect,
178 scf::SCFDialect, tensor::TensorDialect,
179 tf_framework::TFFrameworkDialect, AffineDialect, shape::ShapeDialect,
180 lmhlo::LmhloDialect, linalg::LinalgDialect, math::MathDialect,
181 vector::VectorDialect>();
182 target.addLegalOp<FuncOp, ModuleOp>();
183
184 target.addIllegalDialect<mhlo::MhloDialect>();
185 target.addIllegalOp<tensor::GenerateOp, tensor::ExtractOp,
186 tensor::FromElementsOp, tensor::CastOp, tensor::DimOp,
187 chlo::MinimumBroadcastShapesOp, memref::TensorLoadOp,
188 memref::BufferCastOp, linalg::TensorExpandShapeOp,
189 linalg::TensorCollapseShapeOp>();
190 BufferizeTypeConverter converter;
191 auto typesAreLegal = [&converter](Operation* op) {
192 return converter.isLegal(op->getOperandTypes()) &&
193 converter.isLegal(op->getResultTypes());
194 };
195 target.addDynamicallyLegalOp<ConstantOp, IndexCastOp, RankOp, SelectOp,
196 tf_framework::JITExecuteOp>(typesAreLegal);
197
198 RewritePatternSet patterns(&getContext());
199 linalg::populateLinalgBufferizePatterns(converter, patterns);
200 populateTensorBufferizePatterns(converter, patterns);
201 populateStdBufferizePatterns(converter, patterns);
202 populateEliminateBufferizeMaterializationsPatterns(converter, patterns);
203 populateExtraBufferizePatterns(&getContext(), &converter, &patterns);
204 populateShapeStructuralTypeConversionsAndLegality(converter, patterns,
205 target);
206 scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
207 target);
208
209 auto module = getOperation();
210 if (failed(applyFullConversion(module, target, std::move(patterns)))) {
211 signalPassFailure();
212 }
213 }
214 };
215
216 } // namespace
217
218 std::unique_ptr<OperationPass<ModuleOp> >
CreateComputeOpAndFuncBufferizePass()219 CreateComputeOpAndFuncBufferizePass() {
220 return std::make_unique<ComputeOpAndFuncBufferizePass>();
221 }
222
CreateFinalBufferizePass()223 std::unique_ptr<OperationPass<ModuleOp> > CreateFinalBufferizePass() {
224 return std::make_unique<FinalBufferizePass>();
225 }
226
227 } // namespace transforms
228 } // namespace kernel_gen
229 } // namespace mlir
230