1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <string>
17 #include <utility>
18
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallSet.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Casting.h"
23 #include "llvm/Support/Debug.h"
24 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
25 #include "mlir/IR/Block.h" // from @llvm-project
26 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
27 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
29 #include "mlir/IR/Location.h" // from @llvm-project
30 #include "mlir/IR/MLIRContext.h" // from @llvm-project
31 #include "mlir/IR/OperationSupport.h" // from @llvm-project
32 #include "mlir/IR/PatternMatch.h" // from @llvm-project
33 #include "mlir/IR/TypeRange.h" // from @llvm-project
34 #include "mlir/Support/LogicalResult.h" // from @llvm-project
35 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
36 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
37 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
38 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
39 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
40
// Out-of-line definition of the static constexpr member so that it can be
// ODR-used (e.g. passed by reference); required before C++17 inline-variable
// semantics make in-class constexpr members implicitly inline.
constexpr llvm::StringRef
    mlir::kernel_gen::tf_framework ::JITCompileFromStrOp::kJITEntryFunctionName;
43
44 namespace mlir {
45 namespace kernel_gen {
46 namespace transforms {
47 namespace {
48
// Name of the attribute attached to the JIT entry function so that the
// MLIR-to-LLVM lowering emits a C-compatible interface wrapper for it.
static constexpr StringRef kEmitCInterfaceAttrName = "llvm.emit_c_interface";
50
IsTFOperation(Operation * op)51 bool IsTFOperation(Operation *op) {
52 return op != nullptr &&
53 op->getDialect() ==
54 op->getContext()->getLoadedDialect<TF::TensorFlowDialect>();
55 }
56
// Bundle of code-generation parameters for a JIT-compiled module.
// NOTE(review): this struct lives in the anonymous namespace but is not
// referenced anywhere in the visible portion of this file — confirm whether
// it is dead code that can be removed.
struct ModuleParameters {
  llvm::ArrayRef<int64_t> tile_sizes;
  llvm::ArrayRef<int64_t> unroll_factors;
  int64_t max_supported_rank;
  bool cpu_codegen;
};
63
// Clusters each maximal contiguous run of TF ops and outlines it into the
// body of a new tf_framework.jit_compile op, paired with a
// tf_framework.jit_execute op whose results replace the clustered ops'
// escaping results.
struct TFToJITInvocationsPattern : public RewritePattern {
  // Matches any operation; TF-specific filtering happens in matchAndRewrite.
  explicit TFToJITInvocationsPattern(MLIRContext *ctx)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // Apply to all TF ops except those that are already in a JIT-compiled
    // region.
    if (!IsTFOperation(op) || op->getParentOfType<tf_framework::JITCompileOp>())
      return failure();

    // Find last TF op of the contiguous run that contains `op`.
    while (IsTFOperation(op->getNextNode())) op = op->getNextNode();

    // Find JIT compile region operands and results by walking the run
    // backwards. Because ops are visited in reverse order, every in-block
    // user of `it` has already been visited and added to `cluster`.
    SmallVector<Operation *, 16> cluster;
    llvm::SmallPtrSet<Value, 16> operand_set, result_set;
    Operation *it = op;
    while (IsTFOperation(it)) {
      // Find results that escape the JIT compile region, i.e. values with at
      // least one user outside the cluster.
      for (auto &use : it->getUses()) {
        if (!llvm::is_contained(cluster, use.getOwner()))
          result_set.insert(use.get());
      }

      // Update JIT region operands and results: values defined inside the
      // cluster are not region operands; values consumed by `it` tentatively
      // are (they get erased again if a cluster op visited later turns out to
      // define them).
      for (Value v : it->getResults()) operand_set.erase(v);
      for (Value v : it->getOperands()) operand_set.insert(v);

      cluster.push_back(it);
      it = it->getPrevNode();
    }

    // Introduce order to the operands and results.
    // NOTE(review): SmallPtrSet iterates in a pointer-dependent order, so the
    // operand/result ordering chosen here may vary between runs — confirm
    // that determinism is not required at this point.
    auto operands = llvm::to_vector<16>(operand_set);
    auto results = llvm::to_vector<16>(result_set);
    auto operand_types = llvm::to_vector<16>(
        llvm::map_range(operands, [](Value v) { return v.getType(); }));
    auto result_types = llvm::to_vector<16>(
        llvm::map_range(results, [](Value v) { return v.getType(); }));

    // Create the JIT compile op.
    auto loc = op->getLoc();
    auto jit_compile_op = rewriter.create<tf_framework::JITCompileOp>(
        loc, rewriter.getType<tf_framework::JITCallableType>(), llvm::None);

    // Move the TF operations into the new op's body: clone them (llvm::reverse
    // restores program order since `cluster` was built back-to-front) with
    // operands remapped to the new block's arguments, then yield the mapped
    // escaping results.
    BlockAndValueMapping bvm;
    {
      OpBuilder::InsertionGuard guard(rewriter);
      Block *block =
          rewriter.createBlock(&jit_compile_op.body(), {}, operand_types);
      for (auto it : llvm::zip(operands, block->getArguments()))
        bvm.map(std::get<0>(it), std::get<1>(it));
      rewriter.setInsertionPointToStart(block);
      for (Operation *it : llvm::reverse(cluster)) rewriter.clone(*it, bvm);
      auto mapped_results = llvm::to_vector<16>(
          llvm::map_range(results, [&](Value v) { return bvm.lookup(v); }));
      rewriter.create<tf_framework::JITCompileYieldOp>(loc, TypeRange{},
                                                       mapped_results);
    }

    // Create JIT execute op.
    auto jit_execute_op = rewriter.create<tf_framework::JITExecuteOp>(
        loc, result_types, Value(), jit_compile_op.result(), operands);

    // Replace old TF ops with the new results. Remapping the escaping values
    // to the jit_execute results (overriding their clone mapping) lets the
    // lookup below pick the correct replacement; ops with no remaining uses
    // are simply erased.
    for (auto it : llvm::zip(results, jit_execute_op.results()))
      bvm.map(std::get<0>(it), std::get<1>(it));
    for (Operation *it : cluster) {
      if (it->getUses().empty()) {
        rewriter.eraseOp(it);
        continue;
      }
      auto replacements = llvm::to_vector<16>(llvm::map_range(
          it->getResults(), [&](Value v) { return bvm.lookup(v); }));
      rewriter.replaceOp(it, replacements);
    }

    return success();
  }
};
146
// Packs a tf_framework.jit_compile op (whose region holds the clustered TF
// ops) into a tf_framework.jit_compile_from_str op carrying the region as a
// serialized textual MLIR module, together with the code-generation options
// held by this pattern.
struct PackJITCompileOpPattern
    : public OpRewritePattern<tf_framework::JITCompileOp> {
  using OpRewritePattern<tf_framework::JITCompileOp>::OpRewritePattern;

  explicit PackJITCompileOpPattern(MLIRContext *ctx,
                                   llvm::ArrayRef<StringRef> architectures,
                                   llvm::ArrayRef<int64_t> tile_sizes,
                                   llvm::ArrayRef<int64_t> unroll_factors,
                                   int64_t max_supported_rank, bool enable_ftz,
                                   bool cpu_codegen)
      : OpRewritePattern<tf_framework::JITCompileOp>(ctx),
        architectures(architectures),
        tile_sizes(tile_sizes),
        unroll_factors(unroll_factors),
        max_supported_rank(max_supported_rank),
        enable_ftz(enable_ftz),
        cpu_codegen(cpu_codegen) {}

  LogicalResult matchAndRewrite(tf_framework::JITCompileOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yield_op =
        llvm::cast<tf_framework::JITCompileYieldOp>(body->getTerminator());

    // Temporarily, build the module that would be JIT-compiled. This is only to
    // obtain the serialized code attribute.
    // NOTE(review): `jit_module` is created detached and never erased below,
    // so it looks like it is leaked after serialization — confirm.
    auto loc = op->getLoc();
    OpBuilder tmp_module_builder(getContext(), rewriter.getListener());
    auto jit_module = tmp_module_builder.create<ModuleOp>(loc);
    tmp_module_builder.setInsertionPointToStart(jit_module.getBody());
    // The entry function mirrors the compile region's signature: block
    // arguments become function arguments, the yield's operand types become
    // result types.
    auto jit_function = tmp_module_builder.create<FuncOp>(
        loc, tf_framework::JITCompileFromStrOp::kJITEntryFunctionName,
        tmp_module_builder.getFunctionType(body->getArgumentTypes(),
                                           yield_op->getOperandTypes()));
    jit_function->setAttr(tf_framework::TFFrameworkDialect::kTFEntryAttrName,
                          tmp_module_builder.getUnitAttr());
    jit_function->setAttr(kEmitCInterfaceAttrName,
                          tmp_module_builder.getUnitAttr());
    // Move (not clone) the compile region into the function, then replace the
    // framework-specific yield terminator with a regular return.
    jit_function.getBody().takeBody(op.getBodyRegion());
    tmp_module_builder.setInsertionPointToEnd(&jit_function.getBody().front());
    tmp_module_builder.create<ReturnOp>(loc, yield_op.result());
    rewriter.eraseOp(yield_op);

    // Serialize JIT module.
    std::string code;
    llvm::raw_string_ostream ss(code);
    jit_module.print(ss);

    // Finally, create the new JIT compile op.
    rewriter.replaceOpWithNewOp<tf_framework::JITCompileFromStrOp>(
        op, op->getResultTypes(), op.ctx(), rewriter.getStringAttr(code),
        rewriter.getStrArrayAttr(architectures),
        rewriter.getI64ArrayAttr(tile_sizes),
        rewriter.getI64ArrayAttr(unroll_factors),
        rewriter.getI64IntegerAttr(max_supported_rank),
        rewriter.getBoolAttr(enable_ftz), rewriter.getBoolAttr(cpu_codegen));

    return success();
  }

 private:
  // Code-generation options forwarded verbatim into the new op's attributes.
  // NOTE(review): these ArrayRefs borrow the caller's storage, so the data
  // passed to PopulateTFToJITInvocationPatterns must outlive the pattern set.
  llvm::ArrayRef<StringRef> architectures;
  llvm::ArrayRef<int64_t> tile_sizes;
  llvm::ArrayRef<int64_t> unroll_factors;
  int64_t max_supported_rank;
  bool enable_ftz;
  bool cpu_codegen;
};
215
216 #define GEN_PASS_CLASSES
217 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
218
219 struct TFToJITInvocationPass
220 : public TFToJITInvocationPassBase<TFToJITInvocationPass> {
getDependentDialectsmlir::kernel_gen::transforms::__anon9166ab170111::TFToJITInvocationPass221 void getDependentDialects(DialectRegistry ®istry) const override {
222 registry.insert<mlir::kernel_gen::tf_framework::TFFrameworkDialect>();
223 }
TFToJITInvocationPassmlir::kernel_gen::transforms::__anon9166ab170111::TFToJITInvocationPass224 explicit TFToJITInvocationPass(llvm::ArrayRef<std::string> architectures,
225 llvm::ArrayRef<int64_t> tile_sizes,
226 llvm::ArrayRef<int64_t> unroll_factors,
227 int64_t max_supported_rank, bool enable_ftz,
228 bool cpu_codegen) {
229 architectures_ = architectures;
230 tile_sizes_ = tile_sizes;
231 unroll_factors_ = unroll_factors;
232 max_supported_rank_ = max_supported_rank;
233 enable_ftz_ = enable_ftz;
234 cpu_codegen_ = cpu_codegen;
235 }
236
runOnFunctionmlir::kernel_gen::transforms::__anon9166ab170111::TFToJITInvocationPass237 void runOnFunction() override {
238 MLIRContext *ctx = &getContext();
239 RewritePatternSet patterns(ctx);
240 auto architecture_refs = llvm::to_vector<16>(llvm::map_range(
241 architectures_, [](std::string &arch) { return StringRef(arch); }));
242 PopulateTFToJITInvocationPatterns(
243 ctx, &patterns, architecture_refs, tile_sizes_, unroll_factors_,
244 max_supported_rank_, enable_ftz_, cpu_codegen_);
245 if (failed(
246 applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
247 return signalPassFailure();
248 }
249 }
250 };
251
252 } // namespace
253
PopulateTFToJITInvocationPatterns(MLIRContext * ctx,RewritePatternSet * patterns,llvm::ArrayRef<StringRef> architectures,llvm::ArrayRef<int64_t> tile_sizes,llvm::ArrayRef<int64_t> unroll_factors,int64_t max_supported_rank,bool enable_ftz,bool cpu_codegen)254 void PopulateTFToJITInvocationPatterns(MLIRContext *ctx,
255 RewritePatternSet *patterns,
256 llvm::ArrayRef<StringRef> architectures,
257 llvm::ArrayRef<int64_t> tile_sizes,
258 llvm::ArrayRef<int64_t> unroll_factors,
259 int64_t max_supported_rank,
260 bool enable_ftz, bool cpu_codegen) {
261 patterns->insert<TFToJITInvocationsPattern>(ctx);
262 patterns->insert<PackJITCompileOpPattern>(ctx, architectures, tile_sizes,
263 unroll_factors, max_supported_rank,
264 enable_ftz, cpu_codegen);
265 }
266
CreateTFToJITInvocationPass(llvm::ArrayRef<std::string> architectures,llvm::ArrayRef<int64_t> tile_sizes,llvm::ArrayRef<int64_t> unroll_factors,int64_t max_supported_rank,bool enable_ftz,bool cpu_codegen)267 std::unique_ptr<FunctionPass> CreateTFToJITInvocationPass(
268 llvm::ArrayRef<std::string> architectures,
269 llvm::ArrayRef<int64_t> tile_sizes, llvm::ArrayRef<int64_t> unroll_factors,
270 int64_t max_supported_rank, bool enable_ftz, bool cpu_codegen) {
271 return std::make_unique<TFToJITInvocationPass>(
272 architectures, tile_sizes, unroll_factors, max_supported_rank, enable_ftz,
273 cpu_codegen);
274 }
275
276 } // namespace transforms
277 } // namespace kernel_gen
278 } // namespace mlir
279