• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/partitioning_utils.h"
16 
17 #include <algorithm>
18 #include <utility>
19 
20 #include "tensorflow/core/common_runtime/graph_constructor.h"
21 #include "tensorflow/core/framework/function.h"
22 #include "tensorflow/core/framework/types.h"
23 #include "tensorflow/core/graph/graph.h"
24 #include "tensorflow/core/graph/graph_partition.h"
25 
26 namespace tensorflow {
27 
PartitionFunctionGraph(const DeviceSet & device_set,std::unique_ptr<Graph> graph,std::unordered_map<string,std::unique_ptr<Graph>> * subgraphs)28 Status PartitionFunctionGraph(
29     const DeviceSet& device_set, std::unique_ptr<Graph> graph,
30     std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs) {
31   PartitionOptions partition_options;
32   partition_options.node_to_loc = [](const Node* node) {
33     // TODO(iga): To support the distributed case, first split the graph by
34     // worker (e.g,. using the master session's `SplitByWorker` policy), and
35     // then recursively partition the per-worker shards at the remote worker(s).
36     // Currently, we simply split the graph at device boundaries.
37     return node->assigned_device_name();
38   };
39   int64 edge_name_counter = 0;
40   partition_options.new_name = [&edge_name_counter](const string& prefix) {
41     return strings::StrCat(prefix, "/_", ++edge_name_counter);
42   };
43   partition_options.get_incarnation =
44       [&device_set](const string& name) -> int64 {
45     const Device* d = device_set.FindDeviceByName(name);
46     if (d == nullptr) {
47       return PartitionOptions::kIllegalIncarnation;
48     } else {
49       return d->attributes().incarnation();
50     }
51   };
52   partition_options.control_flow_added = false;
53   std::unordered_map<string, GraphDef> partitions;
54   TF_RETURN_IF_ERROR(Partition(partition_options, graph.get(), &partitions));
55 
56   for (auto& partition : partitions) {
57     const string& device = partition.first;
58     GraphDef& graph_def = partition.second;
59     // Each partition gets a copy of all the
60     // std::unique_ptr<Graph> subgraph(new Graph(graph->flib_def()));
61     std::unique_ptr<Graph> subgraph(
62         new Graph(graph->flib_def().ReachableDefinitions(graph_def)));
63     FunctionLibraryDefinition global_flib(OpRegistry::Global(), {});
64     TF_CHECK_OK(subgraph->AddFunctionLibrary(global_flib.ToProto()));
65     GraphConstructorOptions opts;
66     opts.allow_internal_ops = true;
67     opts.expect_device_spec = true;
68     TF_RETURN_IF_ERROR(
69         ConvertGraphDefToGraph(opts, std::move(graph_def), subgraph.get()));
70     subgraphs->emplace(device, std::move(subgraph));
71   }
72 
73   return Status::OK();
74 }
75 
UpdateArgAndRetvalMetadata(Graph * graph,const string & device_type,std::vector<FunctionArgIndex> * arg_indices,std::vector<int> * ret_indices,std::vector<AllocatorAttributes> * arg_alloc_attrs,std::vector<AllocatorAttributes> * ret_alloc_attrs)76 Status UpdateArgAndRetvalMetadata(
77     Graph* graph, const string& device_type,
78     std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
79     std::vector<AllocatorAttributes>* arg_alloc_attrs,
80     std::vector<AllocatorAttributes>* ret_alloc_attrs) {
81   std::vector<std::pair<Node*, FunctionArgIndex>> arg_nodes;
82   std::vector<std::pair<Node*, int>> ret_nodes;
83   const AttrValue* attr_value;
84 
85   // Find the Arg and Retval nodes, along with their corresponding indices
86   // in the original function.
87   for (Node* node : graph->op_nodes()) {
88     if (node->IsArg()) {
89       TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
90       int index = static_cast<int>(attr_value->i());
91       int sub_index = -1;
92       if (node->attrs().Find("sub_index", &attr_value).ok()) {
93         sub_index = static_cast<int>(attr_value->i());
94       }
95       arg_nodes.emplace_back(node, FunctionArgIndex(index, sub_index));
96     } else if (node->IsRetval()) {
97       TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
98       int index = static_cast<int>(attr_value->i());
99       ret_nodes.emplace_back(node, index);
100     }
101   }
102 
103   // Sort the nodes by index so that the order is stable.
104   //
105   // In particular, this enables calling a single-partition function with
106   // the same signature as the original unpartitioned function.
107   auto arg_comparator = [](std::pair<Node*, FunctionArgIndex> a,
108                            std::pair<Node*, FunctionArgIndex> b) {
109     return std::tie(a.second.index, a.second.sub_index) <
110            std::tie(b.second.index, b.second.sub_index);
111   };
112   std::sort(arg_nodes.begin(), arg_nodes.end(), arg_comparator);
113   auto ret_comparator = [](std::pair<Node*, int> a, std::pair<Node*, int> b) {
114     return a.second < b.second;
115   };
116   std::sort(ret_nodes.begin(), ret_nodes.end(), ret_comparator);
117 
118   arg_indices->reserve(arg_nodes.size());
119   for (const auto& pair : arg_nodes) arg_indices->push_back(pair.second);
120   ret_indices->reserve(ret_nodes.size());
121   for (const auto& pair : ret_nodes) ret_indices->push_back(pair.second);
122 
123   for (int i = 0; i < arg_nodes.size(); ++i) {
124     Node* arg = arg_nodes[i].first;
125     arg->AddAttr("index", i);
126     TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
127     if (arg_alloc_attrs != nullptr) {
128       AllocatorAttributes alloc_attr;
129       DataType type = attr_value->type();
130       MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
131                           device_type == "XLA_GPU")
132                              ? MTypeFromDTypeIntsOnDevice(type)
133                              : MTypeFromDType(type);
134       if (mtype == HOST_MEMORY) {
135         alloc_attr.set_on_host(true);
136       }
137       arg_alloc_attrs->push_back(alloc_attr);
138     }
139   }
140   for (int i = 0; i < ret_nodes.size(); ++i) {
141     Node* ret = ret_nodes[i].first;
142     ret->AddAttr("index", i);
143     TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
144     if (ret_alloc_attrs) {
145       AllocatorAttributes alloc_attr;
146       DataType type = attr_value->type();
147       MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
148                           device_type == "XLA_GPU")
149                              ? MTypeFromDTypeIntsOnDevice(type)
150                              : MTypeFromDType(type);
151       if (mtype == HOST_MEMORY) {
152         alloc_attr.set_on_host(true);
153       }
154       ret_alloc_attrs->push_back(alloc_attr);
155     }
156   }
157 
158   return Status::OK();
159 }
160 
GetName()161 string FunctionNameGenerator::GetName() {
162   while (true) {
163     const string candidate = strings::StrCat(name_, "_", counter_++);
164     if (flib_def_->Find(candidate) == nullptr) {
165       return candidate;
166     }
167   }
168 }
169 
170 }  // namespace tensorflow
171