• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir/Transforms/Passes.h"  // from @llvm-project
17 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
18 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
19 
20 namespace tensorflow {
21 namespace tfrt_compiler {
22 namespace {
23 
FunctionHasSideEffect(mlir::FuncOp func_op,llvm::DenseMap<mlir::FuncOp,bool> & function_side_effect)24 bool FunctionHasSideEffect(
25     mlir::FuncOp func_op,
26     llvm::DenseMap<mlir::FuncOp, bool>& function_side_effect) {
27   auto iter = function_side_effect.find(func_op);
28   if (iter != function_side_effect.end()) return iter->second;
29 
30   auto& block = func_op.front();
31 
32   auto op_has_side_effect = [&](mlir::Operation* op) {
33     if (auto while_op = llvm::dyn_cast<mlir::TF::WhileOp>(op)) {
34       if (while_op.is_stateless()) return false;
35 
36       return FunctionHasSideEffect(while_op.cond_function(),
37                                    function_side_effect) ||
38              FunctionHasSideEffect(while_op.body_function(),
39                                    function_side_effect);
40     }
41 
42     if (auto if_op = llvm::dyn_cast<mlir::TF::IfOp>(op)) {
43       if (if_op.is_stateless()) return false;
44 
45       return FunctionHasSideEffect(if_op.else_function(),
46                                    function_side_effect) ||
47              FunctionHasSideEffect(if_op.then_function(), function_side_effect);
48     }
49 
50     // Though tf.Assert and tf.Timestamp are side-effecting, they do not
51     // interfere with any other side-effecting ops. For now, if control flow
52     // ops' callee functions contain them, we treat them as non-side-effecting.
53     if (llvm::isa<mlir::TF::AssertOp, mlir::TF::TimestampOp>(op)) return false;
54 
55     return !mlir::MemoryEffectOpInterface::hasNoEffect(op);
56   };
57 
58   // Speculatively setting the function to have no side effect to avoid infinite
59   // recursion. The correct side effect will be updated later once more
60   // operations in the block are checked.
61   function_side_effect[func_op] = false;
62 
63   for (mlir::Operation& op : block) {
64     if (op_has_side_effect(&op)) {
65       function_side_effect[func_op] = true;
66       return true;
67     }
68   }
69 
70   function_side_effect[func_op] = false;
71   return false;
72 }
73 
74 // This pass sets `is_stateless` attribute of tf.If and tf.While ops to true if
75 // their callee functions contains only non-side-effecting ops.
76 class OptimizeTfControlFlowSideEffectPass
77     : public mlir::PassWrapper<OptimizeTfControlFlowSideEffectPass,
78                                mlir::OperationPass<mlir::ModuleOp>> {
getArgument() const79   llvm::StringRef getArgument() const final {
80     return "tfrt-optimize-tf-control-flow-side-effect";
81   }
getDescription() const82   llvm::StringRef getDescription() const final {
83     return "Set tf control flow ops to stateless if their callee functions "
84            "contains only non-side-effecting ops";
85   }
runOnOperation()86   void runOnOperation() override {
87     auto module = getOperation();
88     llvm::DenseMap<mlir::FuncOp, bool> function_side_effect;
89 
90     mlir::Builder builder(module.getContext());
91     module.walk([&](mlir::Operation* op) {
92       if (auto while_op = llvm::dyn_cast<mlir::TF::WhileOp>(op)) {
93         if (while_op.is_stateless()) return;
94 
95         if (!FunctionHasSideEffect(while_op.cond_function(),
96                                    function_side_effect) &&
97             !FunctionHasSideEffect(while_op.body_function(),
98                                    function_side_effect)) {
99           while_op->setAttr("is_stateless", builder.getBoolAttr(true));
100         }
101       }
102 
103       if (auto if_op = llvm::dyn_cast<mlir::TF::IfOp>(op)) {
104         if (if_op.is_stateless()) return;
105 
106         if (!FunctionHasSideEffect(if_op.else_function(),
107                                    function_side_effect) &&
108             !FunctionHasSideEffect(if_op.then_function(),
109                                    function_side_effect)) {
110           if_op->setAttr("is_stateless", builder.getBoolAttr(true));
111         }
112       }
113     });
114   }
115 };
116 
117 }  // namespace
118 
119 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateOptimizeTfControlFlowSideEffectPass()120 CreateOptimizeTfControlFlowSideEffectPass() {
121   return std::make_unique<OptimizeTfControlFlowSideEffectPass>();
122 }
123 
124 static mlir::PassRegistration<OptimizeTfControlFlowSideEffectPass>
125     register_pass(CreateOptimizeTfControlFlowSideEffectPass);
126 
127 }  // namespace tfrt_compiler
128 }  // namespace tensorflow
129