1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPH_GRAPH_PARTITION_H_ 17 #define TENSORFLOW_CORE_GRAPH_GRAPH_PARTITION_H_ 18 19 #include <functional> 20 #include <string> 21 #include <unordered_map> 22 #include <vector> 23 24 #include "tensorflow/core/framework/function.h" 25 #include "tensorflow/core/framework/graph.pb.h" 26 #include "tensorflow/core/graph/costmodel.h" 27 #include "tensorflow/core/graph/graph.h" 28 29 namespace tensorflow { 30 31 struct PartitionOptions { 32 // A function that returns a location for the execution of a given 33 // Node. 34 typedef std::function<string(const Node*)> NodeToLocFunc; 35 NodeToLocFunc node_to_loc = nullptr; 36 37 // A function that returns a unique graph node name with the given 38 // prefix. 39 typedef std::function<string(const string&)> NewNameFunc; 40 NewNameFunc new_name = nullptr; 41 42 // A function that returns the incarnation of a device given the 43 // device's fullname. If not found, GetIncarnationFunc should return 44 // kIllegalIncarnation. 
45 static const uint64 kIllegalIncarnation = 0; 46 typedef std::function<uint64(const string&)> GetIncarnationFunc; 47 GetIncarnationFunc get_incarnation = nullptr; 48 49 // If specified, flib_def defines a function library that should be 50 // partitioned and replicated into each resulting partition graphs. 51 const FunctionLibraryDefinition* flib_def = nullptr; 52 53 // True if all the control flow "code" has already been added. The 54 // control flow code needs to be added when we still have the entire 55 // graph before any partitioning. So this flag should be false for 56 // the first partitioning but true for all subsequent partitioning. 57 // 58 // TODO(yuanbyu): We could also make the addition of the control 59 // flow code incremental based on 'node_to_loc'. This makes the 60 // communication a broadcast tree, which could be more efficient when 61 // the number of participating devices is large. 62 bool control_flow_added = false; 63 64 // A function that returns the data type into which the tensor 65 // should be cast before sent over the wire. 66 typedef std::function<DataType(const Edge*)> ShouldCastFunc; 67 ShouldCastFunc should_cast = nullptr; 68 69 // Schedule the execution of the recvs based on their start times 70 // computed by some scheduling algorithm. The recvs are divided into 71 // epochs based on their start times. A recv is enabled only when 72 // execution reaches its epoch - N for some predefined N. 73 bool scheduling_for_recvs = false; 74 // The start time for each node in the graph computed by some scheduling 75 // algorithm. If 'need_to_record_start_times' is true, we record them 76 // in the graph as a node attribute. 77 bool need_to_record_start_times = false; 78 std::vector<Microseconds> start_times; 79 }; 80 81 // Partition "input" graph into a set of graphs, one per location. 82 // The location for node n is derived by calling opts.node_to_loc(n). 
// New nodes added by Partition use "opts.new_name(old_name)" to
// generate node names.
//
// Stores the partitions in *partitions.
//
// Parameters:
//   opts:       callbacks and flags steering the partitioning; see
//               PartitionOptions.
//   input:      the graph to partition.
//               NOTE(review): taken as a non-const pointer; whether the
//               graph is modified is not visible from this header.
//   partitions: output map receiving one GraphDef per partition.
//               Presumably keyed by the location string produced by
//               opts.node_to_loc -- confirm against the implementation.
//
// Returns a Status; the specific error conditions are defined by the
// implementation and are not visible from this header.
Status Partition(const PartitionOptions& opts, Graph* input,
                 std::unordered_map<string, GraphDef>* partitions);

// Add control edges to the partitions to control the ordering
// and timing of the recv nodes based on the start times calculated
// using some scheduling algorithm.
//
// Parameters:
//   opts:       partitioning options.
//               NOTE(review): presumably the scheduling-related fields
//               (e.g. start_times) are what is consulted here -- confirm
//               against the implementation.
//   partitions: the partition graphs to which control edges are added;
//               passed as a mutable pointer.
//
// Returns a Status; the specific error conditions are defined by the
// implementation and are not visible from this header.
Status AddControlEdges(const PartitionOptions& opts,
                       std::unordered_map<string, GraphDef>* partitions);

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPH_GRAPH_PARTITION_H_