/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/auto_parallel.h"

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {
namespace grappler {
const char kAutoParallelPrefix[] = "AutoParallel";

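// Adds a scalar float Const node holding num_replicas_; it serves as the
// divisor used to average the gradients across replicas.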
NodeDef* AutoParallel::AddNodeDivConst() {
  NodeDef* node = graph_.add_node();
  node->set_name(strings::StrCat(kAutoParallelPrefix, "-Div-Const"));
  node->set_op("Const");

  AttrValue attr_data_type;
  attr_data_type.set_type(DT_FLOAT);
  node->mutable_attr()->insert({"dtype", attr_data_type});

  AttrValue attr_tensor;
  auto tensor = attr_tensor.mutable_tensor();
  tensor->add_float_val(static_cast<float>(num_replicas_));
  tensor->set_dtype(DT_FLOAT);
  node->mutable_attr()->insert({"value", attr_tensor});
  return node;
}

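// Adds a RealDiv node computing input_a / input_b; this is how each gradient
// tensor gets divided by the replica count.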
NodeDef* AutoParallel::AddNodeDiv(const string& name, const string& input_a,
                                  const string& input_b) {
  NodeDef* node = graph_.add_node();
  node->set_name(strings::StrCat(kAutoParallelPrefix, "-Div-", name));
  node->set_op("RealDiv");
  node->add_input(input_a);
  node->add_input(input_b);
  AttrValue attr_type;
  attr_type.set_type(DT_FLOAT);
  node->mutable_attr()->insert({"T", attr_type});
  return node;
}

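// Adds a NoOp node that takes a control dependency on every node in `deps`,
// yielding a single node that fires only after all of them have run.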
NodeDef* AutoParallel::AddNodeControl(const string& name,
                                      const std::set<string>& deps,
                                      GraphDef* graph) {
  NodeDef* node = graph->add_node();
  node->set_name(name);
  node->set_op("NoOp");
  for (const auto& dep : deps) {
    node->add_input(strings::StrCat("^", dep));
  }
  return node;
}

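// Analyzes the input graph and partitions its nodes into a shared set and a
// replicated set, splicing in div nodes so that the gradients applied by the
// optimizer ops are averaged over num_replicas_.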
Status AutoParallel::Initialize(const GrapplerItem& item) {
  num_gpus_ = GetNumAvailableGPUs();
  LOG(INFO) << "Number of GPUs: " << num_gpus_;
  item_ = &item;
  graph_ = item.graph;
  LOG(INFO) << "Original graph size: " << graph_.node_size();
  if (item.fetch.empty()) {
    return Status(error::INVALID_ARGUMENT, "No fetch nodes provided.");
  }

  if (item.MainVariables().empty()) {
    return Status(error::INVALID_ARGUMENT, "No variables provided.");
  }

  for (const auto& init : item.init_ops) {
    VLOG(1) << "Init node: " << init;
  }

  for (const auto& fetch : item.fetch) {
    VLOG(1) << "Fetch node: " << fetch;
  }

  for (const auto& var : item.MainVariables()) {
    VLOG(2) << "Variable: " << var->name();
  }

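  // The in-place optimizer update ops we recognize; a node running one of
  // these marks the point where a gradient is applied to a variable.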
  const std::set<string> apply_gradients_ops = {"ApplyGradientDescent",
                                                "ApplyProximalGradientDescent",
                                                "ApplyAdadelta",
                                                "ApplyAdagrad",
                                                "ApplyProximalAdagrad",
                                                "ApplyAdagradDA",
                                                "ApplyFtrl",
                                                "ApplyMomentum",
                                                "ApplyAdam",
                                                "ApplyRMSProp",
                                                "ApplyCenteredRMSProp"};
  for (int i = 0; i < graph_.node_size(); i++) {
    all_nodes_.insert(
        std::make_pair(graph_.node(i).name(), graph_.mutable_node(i)));
    if (apply_gradients_ops.find(graph_.node(i).op()) !=
        apply_gradients_ops.end()) {
      apply_gradients_nodes_.insert(graph_.node(i).name());
      VLOG(2) << "Apply gradients node: " << graph_.node(i).name();
    }
  }

  auto div_const_node = AddNodeDivConst();
  all_nodes_.insert(std::make_pair(div_const_node->name(), div_const_node));
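  // For each optimizer op, the index of the input that carries the gradient
  // tensor (e.g. ApplyGradientDescent takes var, alpha, delta, so its
  // gradient is input 2). The div node is spliced in at that position.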
  std::map<string, int> gradient_pos = {{"ApplyGradientDescent", 2},
                                        {"ApplyProximalGradientDescent", 4},
                                        {"ApplyAdadelta", 6},
                                        {"ApplyAdagrad", 3},
                                        {"ApplyProximalAdagrad", 5},
                                        {"ApplyAdagradDA", 3},
                                        {"ApplyFtrl", 3},
                                        {"ApplyMomentum", 3},
                                        {"ApplyAdam", 9},
                                        {"ApplyRMSProp", 7},
                                        {"ApplyCenteredRMSProp", 8}};
  for (const auto& apply_gradient_node_name : apply_gradients_nodes_) {
    auto apply_gradients_op = all_nodes_[apply_gradient_node_name]->op();
    auto apply_gradients_node = all_nodes_[apply_gradient_node_name];

    auto div_node = AddNodeDiv(
        apply_gradient_node_name,
        apply_gradients_node->input(gradient_pos[apply_gradients_op]),
        div_const_node->name());
    all_nodes_.insert(std::make_pair(div_node->name(), div_node));
    *apply_gradients_node->mutable_input(gradient_pos[apply_gradients_op]) =
        div_node->name();
  }
  LOG(INFO) << "Graph size after adding div nodes: " << all_nodes_.size();

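  // Training nodes are everything the fetch nodes transitively depend on.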
  auto train_nodes = ComputeTransitiveFanin(graph_, item.fetch);
  LOG(INFO) << "Number of training nodes: " << train_nodes.size();

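  // Locate the first dequeue op among the training nodes; it marks the
  // boundary of the input pipeline, which stays shared across replicas.
  // The pointer starts as nullptr so the check below is well defined when
  // the graph contains no dequeue op.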
  const NodeDef* dequeue_node = nullptr;
  for (const auto& train_node : train_nodes) {
    if (IsDequeueOp(*train_node)) {
      dequeue_node = train_node;
      break;
    }
  }

  std::vector<const NodeDef*> input_nodes;
  if (dequeue_node) {
    LOG(INFO) << "Dequeue node: " << dequeue_node->name();
    input_nodes = ComputeTransitiveFanin(graph_, {dequeue_node->name()});
  }
  LOG(INFO) << "Number of input nodes: " << input_nodes.size();

  std::set<string> dont_replicate_nodes;
  for (const auto& variable : item.MainVariables()) {
    dont_replicate_nodes.insert(variable->name());
  }

  for (const auto& init : item.init_ops) {
    dont_replicate_nodes.insert(NodeName(init));
  }

  // Don't replicate the input pipeline nodes, except for the dequeue node
  // itself.
  for (const auto& input_node : input_nodes) {
    if (input_node->name() != dequeue_node->name()) {
      dont_replicate_nodes.insert(input_node->name());
    }
  }

  for (const auto& node : train_nodes) {
    if (dont_replicate_nodes.find(node->name()) == dont_replicate_nodes.end()) {
      replica_nodes_.insert(node->name());
    }
  }
  LOG(INFO) << "Number of replica nodes: " << replica_nodes_.size();

  for (const auto& node : all_nodes_) {
    if (replica_nodes_.find(node.first) == replica_nodes_.end()) {
      shared_nodes_.insert(node.first);
    }
  }
  LOG(INFO) << "Number of shared nodes: " << shared_nodes_.size();
  return Status::OK();
}

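// Returns true if `name` is not in the shared set, i.e. the node is
// replicated per replica.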
bool AutoParallel::NotSharedNode(const string& name) {
  return shared_nodes_.find(name) == shared_nodes_.end();
}

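// Copies the shared nodes into the output graph once, rewiring any input
// that comes from a replicated node to the replica-0 copy of that node.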
void AutoParallel::AddSharedNodes(GraphDef* graph) {
  string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", 0);
  for (const auto& node : shared_nodes_) {
    auto new_node = graph->add_node();
    *new_node = *all_nodes_[node];
    for (int i = 0; i < new_node->input_size(); i++) {
      if (NotSharedNode(NodeName(new_node->input(i)))) {
        string new_name = AddPrefixToNodeName(new_node->input(i), prefix);
        *new_node->mutable_input(i) = new_name;
      }
    }
  }
}

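// Emits one full copy of the replicated subgraph. Replica `number` gets a
// unique name prefix, is pinned to GPU (number % num_gpus_) when GPUs are
// available, and has its inputs from other replicated nodes renamed to the
// same replica's copies; inputs from shared nodes are left untouched.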
void AutoParallel::AddOneReplica(GraphDef* graph, int number) {
  string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", number);
  for (const auto& node : replica_nodes_) {
    auto new_node = graph->add_node();
    *new_node = *all_nodes_[node];
    if (NotSharedNode(new_node->name())) {
      new_node->set_name(AddPrefixToNodeName(new_node->name(), prefix));
      if (num_gpus_ > 0) {
        new_node->set_device(strings::StrCat("/gpu:", number % num_gpus_));
      }
      for (int i = 0; i < new_node->input_size(); i++) {
        if (NotSharedNode(NodeName(new_node->input(i)))) {
          string new_name = AddPrefixToNodeName(new_node->input(i), prefix);
          *new_node->mutable_input(i) = new_name;
        }
      }
    }
  }
}

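// Assembles the parallelized graph: the shared nodes, one copy of the
// replicated subgraph per replica, and a control NoOp that recreates each
// original fetch name as a node waiting on every replica's copy of that
// fetch, so existing fetch names remain valid.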
void AutoParallel::BuildGraph(GraphDef* graph) {
  AddSharedNodes(graph);
  for (int i = 0; i < num_replicas_; i++) {
    AddOneReplica(graph, i);
  }
  std::set<string> fetches;
  for (size_t i = 0; i < item_->fetch.size(); i++) {
    for (int j = 0; j < num_replicas_; j++) {
      string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", j);
      string fetch = AddPrefixToNodeName(item_->fetch[i], prefix);
      fetches.insert(fetch);
    }
  }
  string name_control =
      strings::StrCat(kAutoParallelPrefix, "-Control-", "Fetch");
  auto control = AddNodeControl(name_control, fetches, graph);

  for (const auto& fetch : item_->fetch) {
    AddNodeControl(fetch, {control->name()}, graph);
  }
  *graph->mutable_library() = item_->graph.library();
  *graph->mutable_versions() = item_->graph.versions();
  LOG(INFO) << "Parallelized graph size: " << graph->node_size();
}

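// Grappler entry point: rewrites the training step into num_replicas_
// data-parallel copies that share variables and average their gradients.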
Status AutoParallel::Optimize(Cluster* cluster, const GrapplerItem& item,
                              GraphDef* output) {
  TF_RETURN_IF_ERROR(Initialize(item));
  BuildGraph(output);
  return Status::OK();
}

void AutoParallel::Feedback(Cluster* cluster, const GrapplerItem& item,
                            const GraphDef& optimize_output, double result) {
  // TODO(yaozhang): Add feedback.
}

}  // end namespace grappler
}  // end namespace tensorflow