/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_

#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

// Transformation that converts tf.while_loop() loops into functional While
// operators and tf.cond() conditionals into functional If operators, suitable
// for XLA compilation.
//
// If `node_filter` is defined, then only loops and conditions for whose
// nodes `node_filter` returns true are functionalized.
//
// If `include_functions` is true, then loops and conditions inside of functions
// that are associated with nodes in `graph` (e.g., a function called from a
// node in `graph`) are also functionalized, otherwise they are not.
// This also handles transitive cases, e.g., a function body will be
// functionalized when it is called in another function that is called by some
// node in `graph` (and so on). The node filter also applies here.
40 // 41 // Precondition: 42 // For any node in a loop or condition for which `node_filter` returns true, 43 // all nodes inside of the same loop or condition must also return true 44 // (including nodes in other nested loops and conditions inside of that loop or 45 // condition). 46 // This means that a "not to be functionalized" loop or condition is not allowed 47 // inside a "to be functionalized" loop or condition. 48 // 49 // The user of this function is responsible for using a node filter that 50 // satisfies the above conditions. 51 Status FunctionalizeControlFlow(Graph* graph, 52 FunctionLibraryDefinition* library, 53 const NodeFilter& node_filter = {}, 54 bool include_functions = false); 55 56 Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, 57 FunctionLibraryDefinition* library, 58 const NodeFilter& node_filter = {}, 59 bool include_functions = false); 60 61 // This pass looks at the graph, and turns V1 control flow structure 62 // (Switch/Merge/etc.) into V2 control flow structure (If/While). 63 class FunctionalizeControlFlowPass : public GraphOptimizationPass { 64 public: 65 Status Run(const GraphOptimizationPassOptions& options) override; 66 }; 67 68 // Same as the above but only modifies functions that will be executed by XLA. 69 class FunctionalizeControlFlowForXlaPass : public GraphOptimizationPass { 70 public: 71 Status Run(const GraphOptimizationPassOptions& options) override; 72 }; 73 74 } // namespace tensorflow 75 76 #endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ 77