/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/partitioning_utils.h"

#include <algorithm>
#include <utility>

#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_partition.h"

namespace tensorflow {

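// Splits `graph` into one subgraph per device, keyed by the nodes' assigned
// device names, and stores the result in `subgraphs`. A usage sketch
// (`device_set` and `placed_graph` are illustrative names; the graph's nodes
// are assumed to already carry assigned devices):
//
//   std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
//   TF_RETURN_IF_ERROR(PartitionFunctionGraph(
//       device_set, std::move(placed_graph), &subgraphs));
//   // `subgraphs` now maps each device name to its per-device graph.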
Status PartitionFunctionGraph(
    const DeviceSet& device_set, std::unique_ptr<Graph> graph,
    std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs) {
  PartitionOptions partition_options;
  partition_options.node_to_loc = [](const Node* node) {
    // TODO(iga): To support the distributed case, first split the graph by
    // worker (e.g., using the master session's `SplitByWorker` policy), and
    // then recursively partition the per-worker shards at the remote
    // worker(s). Currently, we simply split the graph at device boundaries.
    return node->assigned_device_name();
  };
  int64_t edge_name_counter = 0;
  partition_options.new_name = [&edge_name_counter](const string& prefix) {
    return strings::StrCat(prefix, "/_", ++edge_name_counter);
  };
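  // Partitioning replaces cross-device edges with Send/Recv pairs; the
  // receiving side records the sending device's incarnation so that a device
  // restart (which changes the incarnation) can be detected at runtime.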
  partition_options.get_incarnation =
      [&device_set](const string& name) -> int64 {
    const Device* d = device_set.FindDeviceByName(name);
    if (d == nullptr) {
      return PartitionOptions::kIllegalIncarnation;
    } else {
      return d->attributes().incarnation();
    }
  };
  partition_options.control_flow_added = false;
  std::unordered_map<string, GraphDef> partitions;
  TF_RETURN_IF_ERROR(Partition(partition_options, graph.get(), &partitions));

  for (auto& partition : partitions) {
    const string& device = partition.first;
    GraphDef& graph_def = partition.second;
    // Each partition gets a new graph.
    std::unique_ptr<Graph> subgraph(
        new Graph(graph->flib_def().default_registry()));
    GraphConstructorOptions opts;
    opts.allow_internal_ops = true;
    opts.expect_device_spec = true;
    TF_RETURN_IF_ERROR(
        ConvertGraphDefToGraph(opts, std::move(graph_def), subgraph.get()));
    subgraphs->emplace(device, std::move(subgraph));
  }

  return Status::OK();
}

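// Rewrites the `_Arg` and `_Retval` nodes of one partition: sorts them by
// their original function indices, renumbers their "index" attributes densely
// within this partition, records the original indices in `arg_indices` /
// `ret_indices`, and, when the corresponding output vectors are non-null,
// fills in allocator attributes derived from the value types and
// `device_type`.
//
// A sketch of a call for one partition (`subgraph` and `device_type` are
// illustrative names):
//
//   std::vector<FunctionArgIndex> arg_indices;
//   std::vector<int> ret_indices;
//   std::vector<AllocatorAttributes> arg_alloc_attrs, ret_alloc_attrs;
//   TF_RETURN_IF_ERROR(UpdateArgAndRetvalMetadata(
//       subgraph.get(), device_type, &arg_indices, &ret_indices,
//       &arg_alloc_attrs, &ret_alloc_attrs));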
Status UpdateArgAndRetvalMetadata(
    Graph* graph, const string& device_type,
    std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
    std::vector<AllocatorAttributes>* arg_alloc_attrs,
    std::vector<AllocatorAttributes>* ret_alloc_attrs) {
  std::vector<std::pair<Node*, FunctionArgIndex>> arg_nodes;
  std::vector<std::pair<Node*, int>> ret_nodes;
  const AttrValue* attr_value;

  // Find the Arg and Retval nodes, along with their corresponding indices
  // in the original function.
  for (Node* node : graph->op_nodes()) {
    if (node->IsArg()) {
      TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
      int index = static_cast<int>(attr_value->i());
      int sub_index = -1;
      if (node->attrs().Find("sub_index", &attr_value).ok()) {
        sub_index = static_cast<int>(attr_value->i());
      }
      arg_nodes.emplace_back(node, FunctionArgIndex(index, sub_index));
    } else if (node->IsRetval()) {
      TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
      int index = static_cast<int>(attr_value->i());
      ret_nodes.emplace_back(node, index);
    }
  }

  // Sort the nodes by index so that the order is stable.
  //
  // In particular, this enables calling a single-partition function with
  // the same signature as the original unpartitioned function.
  auto arg_comparator = [](std::pair<Node*, FunctionArgIndex> a,
                           std::pair<Node*, FunctionArgIndex> b) {
    return std::tie(a.second.index, a.second.sub_index) <
           std::tie(b.second.index, b.second.sub_index);
  };
  std::sort(arg_nodes.begin(), arg_nodes.end(), arg_comparator);
  auto ret_comparator = [](std::pair<Node*, int> a, std::pair<Node*, int> b) {
    return a.second < b.second;
  };
  std::sort(ret_nodes.begin(), ret_nodes.end(), ret_comparator);

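  // Record the original function argument/return indices, in sorted order.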
  arg_indices->reserve(arg_nodes.size());
  for (const auto& pair : arg_nodes) arg_indices->push_back(pair.second);
  ret_indices->reserve(ret_nodes.size());
  for (const auto& pair : ret_nodes) ret_indices->push_back(pair.second);

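  // Renumber each `_Arg` to its position within this partition. If requested,
  // also compute allocator attributes: types that map to HOST_MEMORY are
  // flagged with set_on_host(true); TPU and XLA devices keep integer values
  // in device memory, hence MTypeFromDTypeIntsOnDevice.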
  for (int i = 0; i < arg_nodes.size(); ++i) {
    Node* arg = arg_nodes[i].first;
    arg->AddAttr("index", i);
    TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
    if (arg_alloc_attrs != nullptr) {
      AllocatorAttributes alloc_attr;
      DataType type = attr_value->type();
      MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
                          device_type == "XLA_GPU")
                             ? MTypeFromDTypeIntsOnDevice(type)
                             : MTypeFromDType(type);
      if (mtype == HOST_MEMORY) {
        alloc_attr.set_on_host(true);
      }
      arg_alloc_attrs->push_back(alloc_attr);
    }
  }
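  // Do the same for the `_Retval` nodes.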
  for (int i = 0; i < ret_nodes.size(); ++i) {
    Node* ret = ret_nodes[i].first;
    ret->AddAttr("index", i);
    TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
    if (ret_alloc_attrs) {
      AllocatorAttributes alloc_attr;
      DataType type = attr_value->type();
      MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
                          device_type == "XLA_GPU")
                             ? MTypeFromDTypeIntsOnDevice(type)
                             : MTypeFromDType(type);
      if (mtype == HOST_MEMORY) {
        alloc_attr.set_on_host(true);
      }
      ret_alloc_attrs->push_back(alloc_attr);
    }
  }

  return Status::OK();
}

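// Returns a function name of the form "<name_>_<counter_>" that is not
// already present in `flib_def_`, incrementing the counter until an unused
// name is found.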
string FunctionNameGenerator::GetName() {
  while (true) {
    const string candidate = strings::StrCat(name_, "_", counter_++);
    if (flib_def_->Find(candidate) == nullptr) {
      return candidate;
    }
  }
}

}  // namespace tensorflow