/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_HOST_TRAINING_LOOP_OPTIMIZATION_UTIL_H_
#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_HOST_TRAINING_LOOP_OPTIMIZATION_UTIL_H_

#include <string>
#include <unordered_set>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {
namespace tpu {

struct LoopArgInfo {
  std::string enter_node_name;
  // Exit nodes are optional for loop-invariant while loop args.
  absl::optional<std::string> exit_node_name;
};

struct HostTrainingLoopInfo {
  // Name and attribute information about the function in which the
  // host training loop is contained. If the host training loop is not
  // inside a function call, then `encapsulating_function_name` and
  // `encapsulating_function_attrs` are nullopt.
  absl::optional<std::string> encapsulating_function_name;
  absl::optional<AttrValueMap> encapsulating_function_attrs;

  // Name of the TPU compile node within the host training loop.
  std::string compile_node_name;

  // Name of the while loop in which the TPU compile op is located.
  std::string while_loop_name;

  // Name of the node that represents the loop condition.
  std::string loop_cond_node_name;

  // Exit and Enter node names for each loop argument.
  std::vector<LoopArgInfo> loop_arguments;

  std::unordered_set<Node*> loop_nodes;  // NOLINT
};

// Walks through `graph`, recursing into functional nodes if they exist, and
// identifies all host training loops. Host training loops are the innermost
// while loops that encapsulate a TPUCompileOp node. They are later
// used/analyzed to introduce host-loop-specific optimizations such as
// the sharded weight update.
Status DetectHostTrainingLoop(
    const std::string* current_function_name,
    const AttrValueMap* current_function_attr,
    const FunctionLibraryDefinition* library, Graph* graph,
    FunctionLibraryRuntime* flr,
    std::vector<HostTrainingLoopInfo>* host_training_loops_info);

// Injects VariableReshardOps before and after the TPUExecute op inside the
// host training loop body. This effectively applies the sharded weight update
// to model weight variables.
Status AddReshardOp(Graph* graph, const HostTrainingLoopInfo& host_loop_info);

}  // namespace tpu
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_HOST_TRAINING_LOOP_OPTIMIZATION_UTIL_H_
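
// Illustrative usage sketch (comments only; not part of this header's API
// surface): a graph rewrite pass could first detect host training loops and
// then rewrite each one. This sketch assumes `graph`, `library`, and `flr`
// are supplied by the enclosing optimization pass, and that passing nullptr
// for the current function name/attrs is appropriate when starting from the
// top-level graph (see the `encapsulating_function_*` comments above).
//
//   std::vector<tensorflow::tpu::HostTrainingLoopInfo> host_loops;
//   TF_RETURN_IF_ERROR(tensorflow::tpu::DetectHostTrainingLoop(
//       /*current_function_name=*/nullptr,
//       /*current_function_attr=*/nullptr, library, graph, flr, &host_loops));
//   for (const auto& loop_info : host_loops) {
//     TF_RETURN_IF_ERROR(tensorflow::tpu::AddReshardOp(graph, loop_info));
//   }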