/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeRange.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"

// Out-of-line definition for the static constexpr member declared in
// JITCompileFromStrOp; gives the symbol a definition for odr-use.
// NOTE(review): under C++17 static constexpr data members are implicitly
// inline, which makes this redundant -- confirm the build's language mode
// before removing it.
constexpr llvm::StringRef
    mlir::kernel_gen::tf_framework ::JITCompileFromStrOp::kJITEntryFunctionName;
43 
44 namespace mlir {
45 namespace kernel_gen {
46 namespace transforms {
47 namespace {
48 
// Attribute placed on the JIT entry function (see PackJITCompileOpPattern
// below); requests a C-compatible wrapper interface during LLVM lowering.
static constexpr StringRef kEmitCInterfaceAttrName = "llvm.emit_c_interface";
50 
IsTFOperation(Operation * op)51 bool IsTFOperation(Operation *op) {
52   return op != nullptr &&
53          op->getDialect() ==
54              op->getContext()->getLoadedDialect<TF::TensorFlowDialect>();
55 }
56 
// Plain aggregate bundling the codegen parameters that also appear as options
// of the pass defined below (tile sizes, unroll factors, max supported rank,
// CPU codegen flag).
// NOTE(review): not referenced anywhere in the visible part of this file --
// confirm it is used elsewhere before removing.
struct ModuleParameters {
  llvm::ArrayRef<int64_t> tile_sizes;
  llvm::ArrayRef<int64_t> unroll_factors;
  int64_t max_supported_rank;
  bool cpu_codegen;
};
63 
// Outlines maximal runs of adjacent TF-dialect ops into a
// tf_framework.jit_compile region, followed by a tf_framework.jit_execute op
// that invokes the compiled callable with the cluster's operands.
struct TFToJITInvocationsPattern : public RewritePattern {
  // Matches any op; applicability is decided inside matchAndRewrite.
  explicit TFToJITInvocationsPattern(MLIRContext *ctx)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // Apply to all TF ops except those that are already in a JIT-compiled
    // region.
    if (!IsTFOperation(op) || op->getParentOfType<tf_framework::JITCompileOp>())
      return failure();

    // Find last TF op.
    while (IsTFOperation(op->getNextNode())) op = op->getNextNode();

    // Find JIT compile region operands and results.
    SmallVector<Operation *, 16> cluster;
    llvm::SmallPtrSet<Value, 16> operand_set, result_set;
    Operation *it = op;
    // Walk backwards from the last TF op, growing the cluster one op at a
    // time.
    while (IsTFOperation(it)) {
      // Find results that escape the JIT compile region. Ops already in
      // `cluster` come later in the block, so any other user is an escape.
      for (auto &use : it->getUses()) {
        if (!llvm::is_contained(cluster, use.getOwner()))
          result_set.insert(use.get());
      }

      // Update JIT region operands and results: values produced inside the
      // cluster stop being operands; everything `it` consumes may be one.
      for (Value v : it->getResults()) operand_set.erase(v);
      for (Value v : it->getOperands()) operand_set.insert(v);

      cluster.push_back(it);
      it = it->getPrevNode();
    }

    // Introduce order to the operands and results.
    // NOTE(review): SmallPtrSet iteration order depends on pointer values and
    // is not stable across runs, so the operand/result order chosen here may
    // be non-deterministic -- confirm whether downstream consumers care.
    auto operands = llvm::to_vector<16>(operand_set);
    auto results = llvm::to_vector<16>(result_set);
    auto operand_types = llvm::to_vector<16>(
        llvm::map_range(operands, [](Value v) { return v.getType(); }));
    auto result_types = llvm::to_vector<16>(
        llvm::map_range(results, [](Value v) { return v.getType(); }));

    // Create the JIT compile op.
    auto loc = op->getLoc();
    auto jit_compile_op = rewriter.create<tf_framework::JITCompileOp>(
        loc, rewriter.getType<tf_framework::JITCallableType>(), llvm::None);

    // Move the TF operations into the new op's body.
    BlockAndValueMapping bvm;
    {
      OpBuilder::InsertionGuard guard(rewriter);
      // The region's block arguments mirror the cluster operands one-to-one.
      Block *block =
          rewriter.createBlock(&jit_compile_op.body(), {}, operand_types);
      for (auto it : llvm::zip(operands, block->getArguments()))
        bvm.map(std::get<0>(it), std::get<1>(it));
      rewriter.setInsertionPointToStart(block);
      // `cluster` was collected back-to-front; clone in program order.
      for (Operation *it : llvm::reverse(cluster)) rewriter.clone(*it, bvm);
      auto mapped_results = llvm::to_vector<16>(
          llvm::map_range(results, [&](Value v) { return bvm.lookup(v); }));
      rewriter.create<tf_framework::JITCompileYieldOp>(loc, TypeRange{},
                                                       mapped_results);
    }

    // Create JIT execute op.
    auto jit_execute_op = rewriter.create<tf_framework::JITExecuteOp>(
        loc, result_types, Value(), jit_compile_op.result(), operands);

    // Replace old TF ops with the new results. Re-mapping the escaping values
    // makes bvm point at the jit_execute results instead of the clones.
    for (auto it : llvm::zip(results, jit_execute_op.results()))
      bvm.map(std::get<0>(it), std::get<1>(it));
    for (Operation *it : cluster) {
      // Ops whose results are all unused can simply be erased.
      if (it->getUses().empty()) {
        rewriter.eraseOp(it);
        continue;
      }
      auto replacements = llvm::to_vector<16>(llvm::map_range(
          it->getResults(), [&](Value v) { return bvm.lookup(v); }));
      rewriter.replaceOp(it, replacements);
    }

    return success();
  }
};
146 
// Packs a tf_framework.jit_compile region op into a JITCompileFromStrOp: the
// region's body is wrapped into a stand-alone MLIR module, printed, and
// embedded as a string attribute together with the compilation parameters.
struct PackJITCompileOpPattern
    : public OpRewritePattern<tf_framework::JITCompileOp> {
  using OpRewritePattern<tf_framework::JITCompileOp>::OpRewritePattern;

  // The ArrayRef parameters are stored as views; the caller must keep the
  // underlying storage alive for the lifetime of the pattern (here: the
  // owning pass's runOnFunction scope).
  explicit PackJITCompileOpPattern(MLIRContext *ctx,
                                   llvm::ArrayRef<StringRef> architectures,
                                   llvm::ArrayRef<int64_t> tile_sizes,
                                   llvm::ArrayRef<int64_t> unroll_factors,
                                   int64_t max_supported_rank, bool enable_ftz,
                                   bool cpu_codegen)
      : OpRewritePattern<tf_framework::JITCompileOp>(ctx),
        architectures(architectures),
        tile_sizes(tile_sizes),
        unroll_factors(unroll_factors),
        max_supported_rank(max_supported_rank),
        enable_ftz(enable_ftz),
        cpu_codegen(cpu_codegen) {}

  LogicalResult matchAndRewrite(tf_framework::JITCompileOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yield_op =
        llvm::cast<tf_framework::JITCompileYieldOp>(body->getTerminator());

    // Temporarily, build the module that would be JIT-compiled. This is only to
    // obtain the serialized code attribute.
    auto loc = op->getLoc();
    OpBuilder tmp_module_builder(getContext(), rewriter.getListener());
    auto jit_module = tmp_module_builder.create<ModuleOp>(loc);
    tmp_module_builder.setInsertionPointToStart(jit_module.getBody());
    // The entry function's signature is derived from the region's block
    // arguments and the yielded values.
    auto jit_function = tmp_module_builder.create<FuncOp>(
        loc, tf_framework::JITCompileFromStrOp::kJITEntryFunctionName,
        tmp_module_builder.getFunctionType(body->getArgumentTypes(),
                                           yield_op->getOperandTypes()));
    jit_function->setAttr(tf_framework::TFFrameworkDialect::kTFEntryAttrName,
                          tmp_module_builder.getUnitAttr());
    jit_function->setAttr(kEmitCInterfaceAttrName,
                          tmp_module_builder.getUnitAttr());
    // Steal the region body. `yield_op` stays valid because takeBody moves
    // blocks rather than re-creating ops; the yield is then replaced by a
    // plain return carrying the same operands.
    jit_function.getBody().takeBody(op.getBodyRegion());
    tmp_module_builder.setInsertionPointToEnd(&jit_function.getBody().front());
    tmp_module_builder.create<ReturnOp>(loc, yield_op.result());
    rewriter.eraseOp(yield_op);

    // Serialize JIT module.
    // NOTE(review): no explicit ss.flush() before `code` is consumed below;
    // this relies on raw_string_ostream writing through immediately -- confirm
    // the vendored LLVM makes it unbuffered by default.
    std::string code;
    llvm::raw_string_ostream ss(code);
    jit_module.print(ss);

    // Finally, create the new JIT compile op.
    rewriter.replaceOpWithNewOp<tf_framework::JITCompileFromStrOp>(
        op, op->getResultTypes(), op.ctx(), rewriter.getStringAttr(code),
        rewriter.getStrArrayAttr(architectures),
        rewriter.getI64ArrayAttr(tile_sizes),
        rewriter.getI64ArrayAttr(unroll_factors),
        rewriter.getI64IntegerAttr(max_supported_rank),
        rewriter.getBoolAttr(enable_ftz), rewriter.getBoolAttr(cpu_codegen));

    return success();
  }

 private:
  llvm::ArrayRef<StringRef> architectures;
  llvm::ArrayRef<int64_t> tile_sizes;
  llvm::ArrayRef<int64_t> unroll_factors;
  int64_t max_supported_rank;
  bool enable_ftz;
  bool cpu_codegen;
};
215 
216 #define GEN_PASS_CLASSES
217 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
218 
219 struct TFToJITInvocationPass
220     : public TFToJITInvocationPassBase<TFToJITInvocationPass> {
getDependentDialectsmlir::kernel_gen::transforms::__anon9166ab170111::TFToJITInvocationPass221   void getDependentDialects(DialectRegistry &registry) const override {
222     registry.insert<mlir::kernel_gen::tf_framework::TFFrameworkDialect>();
223   }
TFToJITInvocationPassmlir::kernel_gen::transforms::__anon9166ab170111::TFToJITInvocationPass224   explicit TFToJITInvocationPass(llvm::ArrayRef<std::string> architectures,
225                                  llvm::ArrayRef<int64_t> tile_sizes,
226                                  llvm::ArrayRef<int64_t> unroll_factors,
227                                  int64_t max_supported_rank, bool enable_ftz,
228                                  bool cpu_codegen) {
229     architectures_ = architectures;
230     tile_sizes_ = tile_sizes;
231     unroll_factors_ = unroll_factors;
232     max_supported_rank_ = max_supported_rank;
233     enable_ftz_ = enable_ftz;
234     cpu_codegen_ = cpu_codegen;
235   }
236 
runOnFunctionmlir::kernel_gen::transforms::__anon9166ab170111::TFToJITInvocationPass237   void runOnFunction() override {
238     MLIRContext *ctx = &getContext();
239     RewritePatternSet patterns(ctx);
240     auto architecture_refs = llvm::to_vector<16>(llvm::map_range(
241         architectures_, [](std::string &arch) { return StringRef(arch); }));
242     PopulateTFToJITInvocationPatterns(
243         ctx, &patterns, architecture_refs, tile_sizes_, unroll_factors_,
244         max_supported_rank_, enable_ftz_, cpu_codegen_);
245     if (failed(
246             applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
247       return signalPassFailure();
248     }
249   }
250 };
251 
252 }  // namespace
253 
// Registers the two rewrite patterns defined in this file: clustering and
// outlining of TF ops (TFToJITInvocationsPattern), and serialization of the
// resulting compile regions (PackJITCompileOpPattern). The ArrayRef arguments
// are captured by the packing pattern as views, so their storage must outlive
// the pattern set.
void PopulateTFToJITInvocationPatterns(MLIRContext *ctx,
                                       RewritePatternSet *patterns,
                                       llvm::ArrayRef<StringRef> architectures,
                                       llvm::ArrayRef<int64_t> tile_sizes,
                                       llvm::ArrayRef<int64_t> unroll_factors,
                                       int64_t max_supported_rank,
                                       bool enable_ftz, bool cpu_codegen) {
  patterns->insert<TFToJITInvocationsPattern>(ctx);
  patterns->insert<PackJITCompileOpPattern>(ctx, architectures, tile_sizes,
                                            unroll_factors, max_supported_rank,
                                            enable_ftz, cpu_codegen);
}
266 
CreateTFToJITInvocationPass(llvm::ArrayRef<std::string> architectures,llvm::ArrayRef<int64_t> tile_sizes,llvm::ArrayRef<int64_t> unroll_factors,int64_t max_supported_rank,bool enable_ftz,bool cpu_codegen)267 std::unique_ptr<FunctionPass> CreateTFToJITInvocationPass(
268     llvm::ArrayRef<std::string> architectures,
269     llvm::ArrayRef<int64_t> tile_sizes, llvm::ArrayRef<int64_t> unroll_factors,
270     int64_t max_supported_rank, bool enable_ftz, bool cpu_codegen) {
271   return std::make_unique<TFToJITInvocationPass>(
272       architectures, tile_sizes, unroll_factors, max_supported_rank, enable_ftz,
273       cpu_codegen);
274 }
275 
276 }  // namespace transforms
277 }  // namespace kernel_gen
278 }  // namespace mlir
279