/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/partitioning_utils.h"

#include <algorithm>
#include <utility>

#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_partition.h"

namespace tensorflow {

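// Partitions `graph` by the device name assigned to each node and returns,
// in `subgraphs`, one Graph per device, keyed by the device name.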
Status PartitionFunctionGraph(
    const DeviceSet& device_set, std::unique_ptr<Graph> graph,
    std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs) {
  PartitionOptions partition_options;
  partition_options.node_to_loc = [](const Node* node) {
    // TODO(iga): To support the distributed case, first split the graph by
    // worker (e.g., using the master session's `SplitByWorker` policy), and
    // then recursively partition the per-worker shards at the remote
    // worker(s). Currently, we simply split the graph at device boundaries.
    return node->assigned_device_name();
  };
  int64 edge_name_counter = 0;
  partition_options.new_name = [&edge_name_counter](const string& prefix) {
    return strings::StrCat(prefix, "/_", ++edge_name_counter);
  };
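  // `Partition` uses the device incarnation to tag the _Send/_Recv node pairs
  // it inserts for cross-device edges.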
  partition_options.get_incarnation =
      [&device_set](const string& name) -> int64 {
    const Device* d = device_set.FindDeviceByName(name);
    if (d == nullptr) {
      return PartitionOptions::kIllegalIncarnation;
    } else {
      return d->attributes().incarnation();
    }
  };
  partition_options.control_flow_added = false;
  std::unordered_map<string, GraphDef> partitions;
  TF_RETURN_IF_ERROR(Partition(partition_options, graph.get(), &partitions));

  for (auto& partition : partitions) {
    const string& device = partition.first;
    GraphDef& graph_def = partition.second;
    // Each partition gets its own Graph, whose function library contains only
    // the function definitions reachable from that partition's GraphDef.
    std::unique_ptr<Graph> subgraph(
        new Graph(graph->flib_def().ReachableDefinitions(graph_def)));
    FunctionLibraryDefinition global_flib(OpRegistry::Global(), {});
    TF_CHECK_OK(subgraph->AddFunctionLibrary(global_flib.ToProto()));
    GraphConstructorOptions opts;
    opts.allow_internal_ops = true;
    opts.expect_device_spec = true;
    TF_RETURN_IF_ERROR(
        ConvertGraphDefToGraph(opts, std::move(graph_def), subgraph.get()));
    subgraphs->emplace(device, std::move(subgraph));
  }

  return Status::OK();
}
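
// Rewrites the "index" attrs of the _Arg and _Retval nodes in `graph` so they
// are numbered densely from 0 in a stable order within this partition, and
// reports the original full-function argument/return indices in
// `arg_indices`/`ret_indices`. When non-null, `arg_alloc_attrs` and
// `ret_alloc_attrs` receive host/device allocator attributes for each
// argument and return value on `device_type`.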
Status UpdateArgAndRetvalMetadata(
    Graph* graph, const string& device_type,
    std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
    std::vector<AllocatorAttributes>* arg_alloc_attrs,
    std::vector<AllocatorAttributes>* ret_alloc_attrs) {
  std::vector<std::pair<Node*, FunctionArgIndex>> arg_nodes;
  std::vector<std::pair<Node*, int>> ret_nodes;
  const AttrValue* attr_value;

  // Find the Arg and Retval nodes, along with their corresponding indices
  // in the original function.
  for (Node* node : graph->op_nodes()) {
    if (node->IsArg()) {
      TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
      int index = static_cast<int>(attr_value->i());
      int sub_index = -1;
      if (node->attrs().Find("sub_index", &attr_value).ok()) {
        sub_index = static_cast<int>(attr_value->i());
      }
      arg_nodes.emplace_back(node, FunctionArgIndex(index, sub_index));
    } else if (node->IsRetval()) {
      TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
      int index = static_cast<int>(attr_value->i());
      ret_nodes.emplace_back(node, index);
    }
  }

  // Sort the nodes by index so that the order is stable.
  //
  // In particular, this enables calling a single-partition function with
  // the same signature as the original unpartitioned function.
  auto arg_comparator = [](std::pair<Node*, FunctionArgIndex> a,
                           std::pair<Node*, FunctionArgIndex> b) {
    return std::tie(a.second.index, a.second.sub_index) <
           std::tie(b.second.index, b.second.sub_index);
  };
  std::sort(arg_nodes.begin(), arg_nodes.end(), arg_comparator);
  auto ret_comparator = [](std::pair<Node*, int> a, std::pair<Node*, int> b) {
    return a.second < b.second;
  };
  std::sort(ret_nodes.begin(), ret_nodes.end(), ret_comparator);

  arg_indices->reserve(arg_nodes.size());
  for (const auto& pair : arg_nodes) arg_indices->push_back(pair.second);
  ret_indices->reserve(ret_nodes.size());
  for (const auto& pair : ret_nodes) ret_indices->push_back(pair.second);
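
  // Renumber each _Arg and _Retval node to its dense position within this
  // partition, and derive allocator attributes from the tensor's memory type
  // (host vs. device) for the given device type.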
  for (int i = 0; i < arg_nodes.size(); ++i) {
    Node* arg = arg_nodes[i].first;
    arg->AddAttr("index", i);
    TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
    if (arg_alloc_attrs != nullptr) {
      AllocatorAttributes alloc_attr;
      DataType type = attr_value->type();
      MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
                          device_type == "XLA_GPU")
                             ? MTypeFromDTypeIntsOnDevice(type)
                             : MTypeFromDType(type);
      if (mtype == HOST_MEMORY) {
        alloc_attr.set_on_host(true);
      }
      arg_alloc_attrs->push_back(alloc_attr);
    }
  }
  for (int i = 0; i < ret_nodes.size(); ++i) {
    Node* ret = ret_nodes[i].first;
    ret->AddAttr("index", i);
    TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
    if (ret_alloc_attrs) {
      AllocatorAttributes alloc_attr;
      DataType type = attr_value->type();
      MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
                          device_type == "XLA_GPU")
                             ? MTypeFromDTypeIntsOnDevice(type)
                             : MTypeFromDType(type);
      if (mtype == HOST_MEMORY) {
        alloc_attr.set_on_host(true);
      }
      ret_alloc_attrs->push_back(alloc_attr);
    }
  }

  return Status::OK();
}
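
// Returns a function name of the form "<name_>_<counter_>" that is not
// already present in `flib_def_`, incrementing the counter until an unused
// name is found.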
string FunctionNameGenerator::GetName() {
  while (true) {
    const string candidate = strings::StrCat(name_, "_", counter_++);
    if (flib_def_->Find(candidate) == nullptr) {
      return candidate;
    }
  }
}

}  // namespace tensorflow