/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_CONTROL_FLOW_DEPS_TO_CHAINS_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_CONTROL_FLOW_DEPS_TO_CHAINS_H_

#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {

// Graph optimization pass that moves control flow dependencies in functional
// control flow to chains.
//
// Chains are extra loop variables that serve as tokens for wiring control
// dependencies across loop iterations at a finer granularity than a single
// barrier at the end of each iteration. This is what enables the
// parallel_iterations feature of tf.while_loop.
//
// One separate chain is added for each of the body function's `control_ret`.
//
// For example:
//
//   while i > 0:
//     r = v.read_value()
//     s += expensive_operation(r)
//     assign = v.assign_add(1)  # control: r
//     i += 1
//
// The loop above can safely compute `r` and `assign` ahead of `s`, by the
// as-if rule. The separate switch/merge nodes that the loop lowers into
// support that.
// This transformation enables that to happen by rewriting the loop as follows:
//
//   chain = 0.0
//   while i > 0:
//     r = v.read_value()  # control: chain
//     s += expensive_operation(r)
//     assign = v.assign_add(1)  # control: r
//     i += 1
//     chain = identity(chain)  # control: assign
//
// Only dependencies that need to cross scope boundaries are rewired, because
// the switch/merge lowering process has no other way of dealing with those
// correctly.
//
// Applicability:
//   * The pass is best-effort and conservative. It requires attributes set by
//     tf.while_loop and automatic_control_dependencies. When the required
//     attributes are missing for a particular While node, no change is made to
//     that node; other While nodes are still processed if they do have the
//     needed annotations.
//   * A While node is opted out of the pass by omitting the
//     `_stateful_parallelism=True` attribute on it.
//
// Effects:
//   * If successful, the pass also clears the body function's control_ret,
//     which in effect removes the hard barrier that gates each loop iteration.
//   * When the pass returns with an error, the graph is left in an invalid
//     state; callers must not keep using it.
//
// TODO(mdan): Can we define that more formally?
class ControlFlowDepsToChainsPass : public GraphOptimizationPass {
 public:
  // Runs the rewrite described in the class comment.
  //
  // `options`: pass options supplied by whoever invokes the pass.
  // NOTE(review): presumably this carries the graph being rewritten -- confirm
  // against the declaration of GraphOptimizationPassOptions.
  //
  // Returns a non-OK Status on failure, in which case the graph is left in an
  // invalid state.
  Status Run(const GraphOptimizationPassOptions& options) override;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_CONTROL_FLOW_DEPS_TO_CHAINS_H_