• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/core/common_runtime/partitioning_utils.h"

#include <algorithm>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_partition.h"
22 
23 namespace tensorflow {
24 
PartitionFunctionGraph(const DeviceSet & device_set,std::unique_ptr<Graph> graph,std::unordered_map<string,std::unique_ptr<Graph>> * subgraphs)25 Status PartitionFunctionGraph(
26     const DeviceSet& device_set, std::unique_ptr<Graph> graph,
27     std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs) {
28   PartitionOptions partition_options;
29   partition_options.node_to_loc = [](const Node* node) {
30     // TODO(iga): To support the distributed case, first split the graph by
31     // worker (e.g,. using the master session's `SplitByWorker` policy), and
32     // then recursively partition the per-worker shards at the remote worker(s).
33     // Currently, we simply split the graph at device boundaries.
34     return node->assigned_device_name();
35   };
36   int64 edge_name_counter = 0;
37   partition_options.new_name = [&edge_name_counter](const string& prefix) {
38     return strings::StrCat(prefix, "/_", ++edge_name_counter);
39   };
40   partition_options.get_incarnation =
41       [&device_set](const string& name) -> int64 {
42     const Device* d = device_set.FindDeviceByName(name);
43     if (d == nullptr) {
44       return PartitionOptions::kIllegalIncarnation;
45     } else {
46       return d->attributes().incarnation();
47     }
48   };
49   partition_options.control_flow_added = false;
50   std::unordered_map<string, GraphDef> partitions;
51   TF_RETURN_IF_ERROR(Partition(partition_options, graph.get(), &partitions));
52 
53   for (const auto& partition : partitions) {
54     const string& device = partition.first;
55     const GraphDef& graph_def = partition.second;
56     // Each partition gets a copy of all the
57     // std::unique_ptr<Graph> subgraph(new Graph(graph->flib_def()));
58     std::unique_ptr<Graph> subgraph(
59         new Graph(graph->flib_def().ReachableDefinitions(graph_def)));
60     FunctionLibraryDefinition global_flib(OpRegistry::Global(), {});
61     TF_CHECK_OK(subgraph->AddFunctionLibrary(global_flib.ToProto()));
62     GraphConstructorOptions opts;
63     opts.allow_internal_ops = true;
64     opts.expect_device_spec = true;
65     TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def, subgraph.get()));
66     subgraphs->emplace(device, std::move(subgraph));
67   }
68 
69   return Status::OK();
70 }
71 
UpdateArgAndRetvalMetadata(Graph * subgraph,std::vector<int> * arg_indices,std::vector<int> * ret_indices,std::vector<AllocatorAttributes> * arg_alloc_attrs,std::vector<AllocatorAttributes> * ret_alloc_attrs)72 Status UpdateArgAndRetvalMetadata(
73     Graph* subgraph, std::vector<int>* arg_indices,
74     std::vector<int>* ret_indices,
75     std::vector<AllocatorAttributes>* arg_alloc_attrs,
76     std::vector<AllocatorAttributes>* ret_alloc_attrs) {
77   std::vector<std::pair<Node*, int>> arg_nodes;
78   std::vector<std::pair<Node*, int>> ret_nodes;
79   const AttrValue* attr_value;
80 
81   // Find the Arg and Retval nodes, along with their corresponding indices
82   // in the original function.
83   for (Node* node : subgraph->op_nodes()) {
84     string node_type = node->type_string();
85     if (node->IsArg()) {
86       TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
87       int index = static_cast<int>(attr_value->i());
88       arg_indices->push_back(index);
89       arg_nodes.push_back(std::make_pair(node, index));
90     } else if (node->IsRetval()) {
91       TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
92       int index = static_cast<int>(attr_value->i());
93       ret_indices->push_back(index);
94       ret_nodes.push_back(std::make_pair(node, index));
95     }
96   }
97 
98   for (int i = 0; i < arg_nodes.size(); ++i) {
99     Node* arg = arg_nodes[i].first;
100     arg->AddAttr("index", i);
101     TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
102     AllocatorAttributes alloc_attr;
103     DataType type = attr_value->type();
104     if (MTypeFromDType(type) == HOST_MEMORY) {
105       alloc_attr.set_on_host(true);
106     }
107     arg_alloc_attrs->push_back(alloc_attr);
108   }
109   for (int i = 0; i < ret_nodes.size(); ++i) {
110     Node* ret = ret_nodes[i].first;
111     ret->AddAttr("index", i);
112     TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
113     AllocatorAttributes alloc_attr;
114     DataType type = attr_value->type();
115     if (MTypeFromDType(type) == HOST_MEMORY) {
116       alloc_attr.set_on_host(true);
117     }
118     ret_alloc_attrs->push_back(alloc_attr);
119   }
120 
121   return Status::OK();
122 }
123 
GetArgsForIndices(const std::vector<int> & indices,gtl::ArraySlice<Tensor> arguments)124 std::vector<Tensor> GetArgsForIndices(const std::vector<int>& indices,
125                                       gtl::ArraySlice<Tensor> arguments) {
126   std::vector<Tensor> args;
127   args.reserve(indices.size());
128   for (int i : indices) {
129     args.push_back(arguments[i]);
130   }
131   return args;
132 }
133 
GetName()134 string FunctionNameGenerator::GetName() {
135   for (;; ++counter_) {
136     const string candidate = strings::StrCat(name_, "_", counter_);
137     if (flib_def_->Find(candidate) == nullptr) {
138       return candidate;
139     }
140   }
141 }
142 
143 }  // namespace tensorflow
144