/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/partitioning_utils.h"

#include <algorithm>
#include <utility>

#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_partition.h"

namespace tensorflow {

Status PartitionFunctionGraph(
    const DeviceSet& device_set, std::unique_ptr<Graph> graph,
    std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs) {
  PartitionOptions partition_options;
  partition_options.node_to_loc = [](const Node* node) {
    // TODO(iga): To support the distributed case, first split the graph by
    // worker (e.g., using the master session's `SplitByWorker` policy), and
    // then recursively partition the per-worker shards at the remote worker(s).
    // Currently, we simply split the graph at device boundaries.
    return node->assigned_device_name();
  };
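  // Partition() calls `new_name` to generate fresh, unique names for the
  // nodes and edges it adds when splitting the graph across devices.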
  int64_t edge_name_counter = 0;
  partition_options.new_name = [&edge_name_counter](const string& prefix) {
    return strings::StrCat(prefix, "/_", ++edge_name_counter);
  };
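  // Report each device's incarnation so that the Send/Recv pairs created by
  // Partition() can detect a restarted device; devices that are not in
  // `device_set` map to kIllegalIncarnation.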
  partition_options.get_incarnation =
      [&device_set](const string& name) -> int64 {
    const Device* d = device_set.FindDeviceByName(name);
    if (d == nullptr) {
      return PartitionOptions::kIllegalIncarnation;
    } else {
      return d->attributes().incarnation();
    }
  };
  partition_options.control_flow_added = false;
  std::unordered_map<string, GraphDef> partitions;
  TF_RETURN_IF_ERROR(Partition(partition_options, graph.get(), &partitions));

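  // Partition() produced one GraphDef per device; convert each one back into
  // a Graph keyed by its assigned device name.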
  for (auto& partition : partitions) {
    const string& device = partition.first;
    GraphDef& graph_def = partition.second;
    // Each partition gets a new graph.
    std::unique_ptr<Graph> subgraph(
        new Graph(graph->flib_def().default_registry()));
    GraphConstructorOptions opts;
    opts.allow_internal_ops = true;
    opts.expect_device_spec = true;
    TF_RETURN_IF_ERROR(
        ConvertGraphDefToGraph(opts, std::move(graph_def), subgraph.get()));
    subgraphs->emplace(device, std::move(subgraph));
  }

  return Status::OK();
}

Status UpdateArgAndRetvalMetadata(
    Graph* graph, const string& device_type,
    std::vector<FunctionArgIndex>* arg_indices, std::vector<int>* ret_indices,
    std::vector<AllocatorAttributes>* arg_alloc_attrs,
    std::vector<AllocatorAttributes>* ret_alloc_attrs) {
  std::vector<std::pair<Node*, FunctionArgIndex>> arg_nodes;
  std::vector<std::pair<Node*, int>> ret_nodes;
  const AttrValue* attr_value;

  // Find the Arg and Retval nodes, along with their corresponding indices
  // in the original function.
  for (Node* node : graph->op_nodes()) {
    if (node->IsArg()) {
      TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
      int index = static_cast<int>(attr_value->i());
      int sub_index = -1;
      if (node->attrs().Find("sub_index", &attr_value).ok()) {
        sub_index = static_cast<int>(attr_value->i());
      }
      arg_nodes.emplace_back(node, FunctionArgIndex(index, sub_index));
    } else if (node->IsRetval()) {
      TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
      int index = static_cast<int>(attr_value->i());
      ret_nodes.emplace_back(node, index);
    }
  }

  // Sort the nodes by index so that the order is stable.
  //
  // In particular, this enables calling a single-partition function with
  // the same signature as the original unpartitioned function.
  auto arg_comparator = [](std::pair<Node*, FunctionArgIndex> a,
                           std::pair<Node*, FunctionArgIndex> b) {
    return std::tie(a.second.index, a.second.sub_index) <
           std::tie(b.second.index, b.second.sub_index);
  };
  std::sort(arg_nodes.begin(), arg_nodes.end(), arg_comparator);
  auto ret_comparator = [](std::pair<Node*, int> a, std::pair<Node*, int> b) {
    return a.second < b.second;
  };
  std::sort(ret_nodes.begin(), ret_nodes.end(), ret_comparator);

  arg_indices->reserve(arg_nodes.size());
  for (const auto& pair : arg_nodes) arg_indices->push_back(pair.second);
  ret_indices->reserve(ret_nodes.size());
  for (const auto& pair : ret_nodes) ret_indices->push_back(pair.second);

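  // Re-number the `_Arg` nodes of this partition from 0..N-1 in the stable
  // order established above, and record per-argument allocator attributes:
  // values whose memory type is HOST_MEMORY (e.g. int32 on non-TPU/XLA
  // devices) must be allocated on the host.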
  for (int i = 0; i < arg_nodes.size(); ++i) {
    Node* arg = arg_nodes[i].first;
    arg->AddAttr("index", i);
    TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
    if (arg_alloc_attrs != nullptr) {
      AllocatorAttributes alloc_attr;
      DataType type = attr_value->type();
      MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
                          device_type == "XLA_GPU")
                             ? MTypeFromDTypeIntsOnDevice(type)
                             : MTypeFromDType(type);
      if (mtype == HOST_MEMORY) {
        alloc_attr.set_on_host(true);
      }
      arg_alloc_attrs->push_back(alloc_attr);
    }
  }
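  // Do the same re-numbering and allocator-attribute bookkeeping for the
  // `_Retval` nodes.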
  for (int i = 0; i < ret_nodes.size(); ++i) {
    Node* ret = ret_nodes[i].first;
    ret->AddAttr("index", i);
    TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
    if (ret_alloc_attrs) {
      AllocatorAttributes alloc_attr;
      DataType type = attr_value->type();
      MemoryType mtype = (device_type == "TPU" || device_type == "XLA_CPU" ||
                          device_type == "XLA_GPU")
                             ? MTypeFromDTypeIntsOnDevice(type)
                             : MTypeFromDType(type);
      if (mtype == HOST_MEMORY) {
        alloc_attr.set_on_host(true);
      }
      ret_alloc_attrs->push_back(alloc_attr);
    }
  }

  return Status::OK();
}

string FunctionNameGenerator::GetName() {
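  // Keep appending an increasing counter to `name_` until the candidate does
  // not collide with a function already registered in `flib_def_`.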
  while (true) {
    const string candidate = strings::StrCat(name_, "_", counter_++);
    if (flib_def_->Find(candidate) == nullptr) {
      return candidate;
    }
  }
}

}  // namespace tensorflow