1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir/Transforms/Passes.h" // from @llvm-project
17 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
18 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
19
20 namespace tensorflow {
21 namespace tfrt_compiler {
22 namespace {
23
FunctionHasSideEffect(mlir::FuncOp func_op,llvm::DenseMap<mlir::FuncOp,bool> & function_side_effect)24 bool FunctionHasSideEffect(
25 mlir::FuncOp func_op,
26 llvm::DenseMap<mlir::FuncOp, bool>& function_side_effect) {
27 auto iter = function_side_effect.find(func_op);
28 if (iter != function_side_effect.end()) return iter->second;
29
30 auto& block = func_op.front();
31
32 auto op_has_side_effect = [&](mlir::Operation* op) {
33 if (auto while_op = llvm::dyn_cast<mlir::TF::WhileOp>(op)) {
34 if (while_op.is_stateless()) return false;
35
36 return FunctionHasSideEffect(while_op.cond_function(),
37 function_side_effect) ||
38 FunctionHasSideEffect(while_op.body_function(),
39 function_side_effect);
40 }
41
42 if (auto if_op = llvm::dyn_cast<mlir::TF::IfOp>(op)) {
43 if (if_op.is_stateless()) return false;
44
45 return FunctionHasSideEffect(if_op.else_function(),
46 function_side_effect) ||
47 FunctionHasSideEffect(if_op.then_function(), function_side_effect);
48 }
49
50 // Though tf.Assert and tf.Timestamp are side-effecting, they do not
51 // interfere with any other side-effecting ops. For now, if control flow
52 // ops' callee functions contain them, we treat them as non-side-effecting.
53 if (llvm::isa<mlir::TF::AssertOp, mlir::TF::TimestampOp>(op)) return false;
54
55 return !mlir::MemoryEffectOpInterface::hasNoEffect(op);
56 };
57
58 // Speculatively setting the function to have no side effect to avoid infinite
59 // recursion. The correct side effect will be updated later once more
60 // operations in the block are checked.
61 function_side_effect[func_op] = false;
62
63 for (mlir::Operation& op : block) {
64 if (op_has_side_effect(&op)) {
65 function_side_effect[func_op] = true;
66 return true;
67 }
68 }
69
70 function_side_effect[func_op] = false;
71 return false;
72 }
73
74 // This pass sets `is_stateless` attribute of tf.If and tf.While ops to true if
75 // their callee functions contains only non-side-effecting ops.
76 class OptimizeTfControlFlowSideEffectPass
77 : public mlir::PassWrapper<OptimizeTfControlFlowSideEffectPass,
78 mlir::OperationPass<mlir::ModuleOp>> {
getArgument() const79 llvm::StringRef getArgument() const final {
80 return "tfrt-optimize-tf-control-flow-side-effect";
81 }
getDescription() const82 llvm::StringRef getDescription() const final {
83 return "Set tf control flow ops to stateless if their callee functions "
84 "contains only non-side-effecting ops";
85 }
runOnOperation()86 void runOnOperation() override {
87 auto module = getOperation();
88 llvm::DenseMap<mlir::FuncOp, bool> function_side_effect;
89
90 mlir::Builder builder(module.getContext());
91 module.walk([&](mlir::Operation* op) {
92 if (auto while_op = llvm::dyn_cast<mlir::TF::WhileOp>(op)) {
93 if (while_op.is_stateless()) return;
94
95 if (!FunctionHasSideEffect(while_op.cond_function(),
96 function_side_effect) &&
97 !FunctionHasSideEffect(while_op.body_function(),
98 function_side_effect)) {
99 while_op->setAttr("is_stateless", builder.getBoolAttr(true));
100 }
101 }
102
103 if (auto if_op = llvm::dyn_cast<mlir::TF::IfOp>(op)) {
104 if (if_op.is_stateless()) return;
105
106 if (!FunctionHasSideEffect(if_op.else_function(),
107 function_side_effect) &&
108 !FunctionHasSideEffect(if_op.then_function(),
109 function_side_effect)) {
110 if_op->setAttr("is_stateless", builder.getBoolAttr(true));
111 }
112 }
113 });
114 }
115 };
116
117 } // namespace
118
119 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateOptimizeTfControlFlowSideEffectPass()120 CreateOptimizeTfControlFlowSideEffectPass() {
121 return std::make_unique<OptimizeTfControlFlowSideEffectPass>();
122 }
123
124 static mlir::PassRegistration<OptimizeTfControlFlowSideEffectPass>
125 register_pass(CreateOptimizeTfControlFlowSideEffectPass);
126
127 } // namespace tfrt_compiler
128 } // namespace tensorflow
129