• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for translating mixed IR to buffer form.
17 // Currently it supports MHLO and some operations from the Standard dialect.
18 
19 #include <memory>
20 #include <utility>
21 
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/raw_ostream.h"
24 #include "mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
25 #include "mlir/Dialect/Complex/IR/Complex.h"  // from @llvm-project
26 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"  // from @llvm-project
27 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"  // from @llvm-project
28 #include "mlir/Dialect/Math/IR/Math.h"  // from @llvm-project
29 #include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
30 #include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
31 #include "mlir/Dialect/SCF/Transforms.h"  // from @llvm-project
32 #include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
33 #include "mlir/Dialect/Shape/Transforms/Passes.h"  // from @llvm-project
34 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
35 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"  // from @llvm-project
36 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"  // from @llvm-project
37 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
38 #include "mlir/Dialect/Tensor/Transforms/Passes.h"  // from @llvm-project
39 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
40 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
41 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
42 #include "mlir/IR/Operation.h"  // from @llvm-project
43 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
44 #include "mlir/IR/Visitors.h"  // from @llvm-project
45 #include "mlir/Transforms/Bufferize.h"  // from @llvm-project
46 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
47 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
48 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
49 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
50 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
51 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
52 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/type_conversion.h"
53 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
54 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
55 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
56 
57 namespace mlir {
58 namespace kernel_gen {
59 namespace transforms {
60 namespace {
61 
62 #define GEN_PASS_CLASSES
63 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
64 
/// Helpers and a custom type converter that automatically populate the relevant
/// materializations and type conversions for bufferization.
67 
materializeTensorLoad(OpBuilder & builder,TensorType type,ValueRange inputs,Location loc)68 static Value materializeTensorLoad(OpBuilder& builder, TensorType type,
69                                    ValueRange inputs, Location loc) {
70   assert(inputs.size() == 1);
71   assert(inputs[0].getType().isa<BaseMemRefType>());
72   return builder.create<memref::TensorLoadOp>(loc, type, inputs[0]);
73 }
74 
75 // TODO(pifon): Remove as soon as https://reviews.llvm.org/D93126 is landed.
76 class CustomBufferizeTypeConverter : public BufferizeTypeConverter {
77  public:
CustomBufferizeTypeConverter()78   CustomBufferizeTypeConverter() {
79     // Keep all types unchanged.
80     addConversion([](Type type) { return type; });
81     // Convert RankedTensorType to MemRefType.
82     addConversion([](RankedTensorType type) -> Type {
83       return MemRefType::get(type.getShape(), type.getElementType());
84     });
85     // Convert UnrankedTensorType to UnrankedMemRefType.
86     addConversion([](UnrankedTensorType type) -> Type {
87       return UnrankedMemRefType::get(type.getElementType(), 0);
88     });
89     addArgumentMaterialization(materializeTensorLoad);
90     addSourceMaterialization(materializeTensorLoad);
91     addTargetMaterialization([](OpBuilder& builder, BaseMemRefType type,
92                                 ValueRange inputs, Location loc) -> Value {
93       assert(inputs.size() == 1);
94       // Target materialization is invoked if the new operand type does not
95       // match the expected type. A special case is when the new operand type is
96       // a memref with a specified layout, i.e. non-empty affine map.
97       // TODO(pifon) : Change how target materialization is invoked in dialect
98       // conversion.
99       if (auto memref_type = inputs[0].getType().dyn_cast<MemRefType>()) {
100         assert(!memref_type.getAffineMaps().empty());
101         return inputs[0];
102       }
103       assert(inputs[0].getType().isa<TensorType>());
104       return builder.create<memref::BufferCastOp>(loc, type, inputs[0]);
105     });
106   }
107 };
108 
109 struct ComputeOpAndFuncBufferizePass
110     : public ComputeOpAndFuncBufferizePassBase<ComputeOpAndFuncBufferizePass> {
111   // TODO(b/173201243): Move to tablegen.
getDependentDialectsmlir::kernel_gen::transforms::__anon5fbbc5010111::ComputeOpAndFuncBufferizePass112   void getDependentDialects(DialectRegistry& registry) const override {
113     registry.insert<lmhlo::LmhloDialect, memref::MemRefDialect>();
114   }
115 
116  public:
runOnOperationmlir::kernel_gen::transforms::__anon5fbbc5010111::ComputeOpAndFuncBufferizePass117   void runOnOperation() override {
118     RewritePatternSet patterns(&getContext());
119     auto& context = getContext();
120     ConversionTarget target(context);
121     target.addLegalDialect<complex::ComplexDialect, lmhlo::LmhloDialect,
122                            memref::MemRefDialect, StandardOpsDialect,
123                            tensor::TensorDialect, math::MathDialect>();
124     target.addLegalOp<UnrealizedConversionCastOp>();
125     target.addIllegalDialect<mhlo::MhloDialect>();
126     target.addIllegalOp<tensor::ExtractSliceOp, tensor::InsertSliceOp>();
127 
128     CustomBufferizeTypeConverter converter;
129     mhlo::RemoveSignTypeConverter remove_sign_converter;
130 
131     // Configure bufferize pattern for functions and lhlo.
132     mhlo::populateHLOToMemrefConversionPattern(
133         &converter, &remove_sign_converter, &patterns,
134         /*enforce_identity_map=*/false);
135     populateFuncOpTypeConversionPattern(patterns, converter);
136     populateCallOpTypeConversionPattern(patterns, converter);
137     populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
138     populateReturnOpTypeConversionPattern(patterns, converter);
139 
140     // Configure legality and structural patterns.
141     populateBufferizeMaterializationLegality(target);
142     linalg::populateLinalgBufferizePatterns(converter, patterns);
143     populateShapeStructuralTypeConversionsAndLegality(converter, patterns,
144                                                       target);
145     scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
146                                                          target);
147     // TODO(herhut): Move this legality configuration to bufferize itself?
148     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
149       auto inputs = op.getType().getInputs();
150       auto results = op.getType().getResults();
151       return converter.isLegal(inputs) && converter.isLegal(results) &&
152              converter.isLegal(&op.getBody());
153     });
154     auto isLegalOp = [&](Operation* op) { return converter.isLegal(op); };
155     target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOp);
156     target.addDynamicallyLegalOp<CallOp, ReturnOp>(isLegalOp);
157 
158     if (failed(applyPartialConversion(getOperation(), target,
159                                       std::move(patterns))))
160       signalPassFailure();
161   }
162 };
163 
164 struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> {
165   // TODO(b/173201243): Move to tablegen.
getDependentDialectsmlir::kernel_gen::transforms::__anon5fbbc5010111::FinalBufferizePass166   void getDependentDialects(DialectRegistry& registry) const override {
167     registry.insert<AffineDialect, memref::MemRefDialect, scf::SCFDialect,
168                     shape::ShapeDialect, tensor::TensorDialect,
169                     tf_framework::TFFrameworkDialect, lmhlo::LmhloDialect>();
170   }
171 
172  public:
runOnOperationmlir::kernel_gen::transforms::__anon5fbbc5010111::FinalBufferizePass173   void runOnOperation() override {
174     auto& context = getContext();
175     ConversionTarget target(context);
176     target.addLegalDialect<
177         complex::ComplexDialect, memref::MemRefDialect, StandardOpsDialect,
178         scf::SCFDialect, tensor::TensorDialect,
179         tf_framework::TFFrameworkDialect, AffineDialect, shape::ShapeDialect,
180         lmhlo::LmhloDialect, linalg::LinalgDialect, math::MathDialect,
181         vector::VectorDialect>();
182     target.addLegalOp<FuncOp, ModuleOp>();
183 
184     target.addIllegalDialect<mhlo::MhloDialect>();
185     target.addIllegalOp<tensor::GenerateOp, tensor::ExtractOp,
186                         tensor::FromElementsOp, tensor::CastOp, tensor::DimOp,
187                         chlo::MinimumBroadcastShapesOp, memref::TensorLoadOp,
188                         memref::BufferCastOp, linalg::TensorExpandShapeOp,
189                         linalg::TensorCollapseShapeOp>();
190     BufferizeTypeConverter converter;
191     auto typesAreLegal = [&converter](Operation* op) {
192       return converter.isLegal(op->getOperandTypes()) &&
193              converter.isLegal(op->getResultTypes());
194     };
195     target.addDynamicallyLegalOp<ConstantOp, IndexCastOp, RankOp, SelectOp,
196                                  tf_framework::JITExecuteOp>(typesAreLegal);
197 
198     RewritePatternSet patterns(&getContext());
199     linalg::populateLinalgBufferizePatterns(converter, patterns);
200     populateTensorBufferizePatterns(converter, patterns);
201     populateStdBufferizePatterns(converter, patterns);
202     populateEliminateBufferizeMaterializationsPatterns(converter, patterns);
203     populateExtraBufferizePatterns(&getContext(), &converter, &patterns);
204     populateShapeStructuralTypeConversionsAndLegality(converter, patterns,
205                                                       target);
206     scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
207                                                          target);
208 
209     auto module = getOperation();
210     if (failed(applyFullConversion(module, target, std::move(patterns)))) {
211       signalPassFailure();
212     }
213   }
214 };
215 
216 }  // namespace
217 
218 std::unique_ptr<OperationPass<ModuleOp> >
CreateComputeOpAndFuncBufferizePass()219 CreateComputeOpAndFuncBufferizePass() {
220   return std::make_unique<ComputeOpAndFuncBufferizePass>();
221 }
222 
CreateFinalBufferizePass()223 std::unique_ptr<OperationPass<ModuleOp> > CreateFinalBufferizePass() {
224   return std::make_unique<FinalBufferizePass>();
225 }
226 
227 }  // namespace transforms
228 }  // namespace kernel_gen
229 }  // namespace mlir
230