/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/partitioning_utils.h"

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_partition.h"

namespace tensorflow {

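// Partitions `graph` by assigned device and fills `*subgraphs` with one
// subgraph per device. Illustrative call site (device names are made up):
//
//   std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
//   TF_RETURN_IF_ERROR(
//       PartitionFunctionGraph(device_set, std::move(graph), &subgraphs));
//   // subgraphs now maps, e.g., "/job:a/replica:0/task:0/device:CPU:0" to
//   // the piece of the graph whose nodes were assigned to that device.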
Status PartitionFunctionGraph(
    const DeviceSet& device_set, std::unique_ptr<Graph> graph,
    std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs) {
  PartitionOptions partition_options;
  partition_options.node_to_loc = [](const Node* node) {
    // TODO(iga): To support the distributed case, first split the graph by
    // worker (e.g., using the master session's `SplitByWorker` policy), and
    // then recursively partition the per-worker shards at the remote
    // worker(s). Currently, we simply split the graph at device boundaries.
    return node->assigned_device_name();
  };
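  // Partitioning replaces each cross-device edge with a paired _Send/_Recv;
  // `new_name` gives every node it inserts a graph-unique name by appending
  // a monotonically increasing counter to the prefix Partition() supplies.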
  int64 edge_name_counter = 0;
  partition_options.new_name = [&edge_name_counter](const string& prefix) {
    return strings::StrCat(prefix, "/_", ++edge_name_counter);
  };
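  // A device's incarnation number changes when the device restarts, and it
  // is stamped onto the _Send/_Recv pairs so that a rendezvous with a
  // restarted device can be detected; unknown devices map to
  // kIllegalIncarnation.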
  partition_options.get_incarnation =
      [&device_set](const string& name) -> int64 {
    const Device* d = device_set.FindDeviceByName(name);
    if (d == nullptr) {
      return PartitionOptions::kIllegalIncarnation;
    } else {
      return d->attributes().incarnation();
    }
  };
  partition_options.control_flow_added = false;
  std::unordered_map<string, GraphDef> partitions;
  TF_RETURN_IF_ERROR(Partition(partition_options, graph.get(), &partitions));

  for (const auto& partition : partitions) {
    const string& device = partition.first;
    const GraphDef& graph_def = partition.second;
    // Give each partition a function library containing only the function
    // definitions reachable from its nodes, rather than a copy of the full
    // library attached to the original graph.
    std::unique_ptr<Graph> subgraph(
        new Graph(graph->flib_def().ReachableDefinitions(graph_def)));
    FunctionLibraryDefinition global_flib(OpRegistry::Global(), {});
    TF_CHECK_OK(subgraph->AddFunctionLibrary(global_flib.ToProto()));
    GraphConstructorOptions opts;
    opts.allow_internal_ops = true;
    opts.expect_device_spec = true;
    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def, subgraph.get()));
    subgraphs->emplace(device, std::move(subgraph));
  }

  return Status::OK();
}

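// Rewrites the "index" attrs of the _Arg and _Retval nodes in `subgraph` to
// be dense (0, 1, ...) so that the subgraph can run as a self-contained
// function, records the original indices in `arg_indices`/`ret_indices` so
// callers can map values back to the original function's signature, and
// fills in allocator attributes (host vs. device memory) for every arg and
// retval. For example, if this partition received the original args 0 and 2,
// `arg_indices` becomes {0, 2} (in node-iteration order) and the two _Arg
// nodes are renumbered to 0 and 1.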
Status UpdateArgAndRetvalMetadata(
    Graph* subgraph, std::vector<int>* arg_indices,
    std::vector<int>* ret_indices,
    std::vector<AllocatorAttributes>* arg_alloc_attrs,
    std::vector<AllocatorAttributes>* ret_alloc_attrs) {
  std::vector<std::pair<Node*, int>> arg_nodes;
  std::vector<std::pair<Node*, int>> ret_nodes;
  const AttrValue* attr_value;

  // Find the Arg and Retval nodes, along with their corresponding indices
  // in the original function.
  for (Node* node : subgraph->op_nodes()) {
    if (node->IsArg()) {
      TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
      int index = static_cast<int>(attr_value->i());
      arg_indices->push_back(index);
      arg_nodes.push_back(std::make_pair(node, index));
    } else if (node->IsRetval()) {
      TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
      int index = static_cast<int>(attr_value->i());
      ret_indices->push_back(index);
      ret_nodes.push_back(std::make_pair(node, index));
    }
  }

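  // Renumber the args/retvals densely and record where each must live:
  // MTypeFromDType() returns HOST_MEMORY for dtypes (e.g. DT_INT32) that
  // kernels expect in host memory even on accelerator devices, so those
  // entries get allocator attributes with on_host set.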
  for (int i = 0; i < arg_nodes.size(); ++i) {
    Node* arg = arg_nodes[i].first;
    arg->AddAttr("index", i);
    TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
    AllocatorAttributes alloc_attr;
    DataType type = attr_value->type();
    if (MTypeFromDType(type) == HOST_MEMORY) {
      alloc_attr.set_on_host(true);
    }
    arg_alloc_attrs->push_back(alloc_attr);
  }
  for (int i = 0; i < ret_nodes.size(); ++i) {
    Node* ret = ret_nodes[i].first;
    ret->AddAttr("index", i);
    TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
    AllocatorAttributes alloc_attr;
    DataType type = attr_value->type();
    if (MTypeFromDType(type) == HOST_MEMORY) {
      alloc_attr.set_on_host(true);
    }
    ret_alloc_attrs->push_back(alloc_attr);
  }

  return Status::OK();
}

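// Picks out, in order, the tensors at `indices` from the full argument list.
// Used with the `arg_indices` computed above: e.g. for arguments {a, b, c}
// and indices {0, 2}, the partition receives {a, c}.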
std::vector<Tensor> GetArgsForIndices(const std::vector<int>& indices,
                                      gtl::ArraySlice<Tensor> arguments) {
  std::vector<Tensor> args;
  args.reserve(indices.size());
  for (int i : indices) {
    args.push_back(arguments[i]);
  }
  return args;
}

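// Returns a name of the form "<name_>_<counter_>" that is not yet present in
// `flib_def_`, advancing `counter_` past any candidates that collide with
// existing function definitions.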
string FunctionNameGenerator::GetName() {
  for (;; ++counter_) {
    const string candidate = strings::StrCat(name_, "_", counter_);
    if (flib_def_->Find(candidate) == nullptr) {
      return candidate;
    }
  }
}

}  // namespace tensorflow