• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/core/graph/collective_order.h"

#include <algorithm>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/graph/algorithm.h"
20 
21 namespace tensorflow {
22 namespace {
23 
24 // Find all CollectiveReduce nodes and the existing data dependencies between
25 // them.
DiscoverDataDependencies(const Graph * graph,std::vector<Node * > * collective_nodes,std::vector<int32> * instance_keys,absl::flat_hash_map<Node *,absl::flat_hash_set<int32>> * data_dependencies)26 Status DiscoverDataDependencies(
27     const Graph* graph, std::vector<Node*>* collective_nodes,
28     std::vector<int32>* instance_keys,
29     absl::flat_hash_map<Node*, absl::flat_hash_set<int32>>* data_dependencies) {
30   Status s;
31   // Algorithm: do Reverse DFS starting at sink.  `node_leave` is called when
32   // all parents of `node` have been visited.  At that point,
33   // `data_dependencies[node]` is a list containing `instance_key` of every
34   // `CollectiveReduce` on which `node` has a data dependency.
35   // For this node's children, add all these instance keys.  Also, if this node
36   // is collective, add as a dependency for the children.
37   auto node_leave = [collective_nodes, instance_keys, data_dependencies,
38                      &s](Node* node) {
39     int32 instance_key;
40     bool enter_node =
41         node->IsCollective() && node->type_string() == "CollectiveReduce";
42     if (enter_node) {
43       Status get_attr_status =
44           GetNodeAttr(node->attrs(), "instance_key", &instance_key);
45       s.Update(get_attr_status);
46       collective_nodes->push_back(node);
47       instance_keys->push_back(instance_key);
48       VLOG(2) << "collective node " << node->DebugString();
49     }
50     // Avoid reference invalidation of `node_deps`.
51     data_dependencies->reserve(data_dependencies->size() + 1 +
52                                node->out_edges().size());
53     const auto& node_deps = (*data_dependencies)[node];
54     for (const Edge* out_edge : node->out_edges()) {
55       auto& child_deps = (*data_dependencies)[out_edge->dst()];
56       child_deps.insert(node_deps.begin(), node_deps.end());
57       if (enter_node && s.ok()) {
58         child_deps.insert(instance_key);
59       }
60     }
61   };
62   ReverseDFS(*graph, nullptr, node_leave);
63   return s;
64 }
65 
66 // Given a list of `collective_nodes` and `data_dependencies` between the
67 // collective nodes, create control dependencies between concurrent collectives
68 // and store in `dependency_edges`.
69 // If there exists an edge a -> b then `dependency_edges[a]` contains `b`
CreateControlDependencies(const std::vector<Node * > & collective_nodes,const std::vector<int32> & instance_keys,absl::flat_hash_map<Node *,absl::flat_hash_set<int32>> * data_dependencies,absl::flat_hash_map<Node *,absl::flat_hash_set<Node * >> * dependency_edges)70 Status CreateControlDependencies(
71     const std::vector<Node*>& collective_nodes,
72     const std::vector<int32>& instance_keys,
73     absl::flat_hash_map<Node*, absl::flat_hash_set<int32>>* data_dependencies,
74     absl::flat_hash_map<Node*, absl::flat_hash_set<Node*>>* dependency_edges) {
75   // If there exists some path a -> ... -> b then `all_paths[a]` contains `b`
76   absl::flat_hash_map<Node*, absl::flat_hash_set<Node*>> all_paths;
77   for (int i = 0; i < collective_nodes.size() - 1; i++) {
78     if (!collective_nodes[i]->IsCollective() ||
79         collective_nodes[i]->type_string() != "CollectiveReduce") {
80       return errors::Internal("Unexpected node ",
81                               collective_nodes[i]->DebugString());
82     }
83     const auto& deps_i = (*data_dependencies)[collective_nodes[i]];
84     for (int j = i + 1; j < collective_nodes.size(); j++) {
85       if (collective_nodes[i]->requested_device() !=
86           collective_nodes[j]->requested_device()) {
87         continue;
88       }
89       if (instance_keys[i] == instance_keys[j]) {
90         return errors::Internal("Unexpected same instance_key ",
91                                 instance_keys[i],
92                                 " on 2 nodes with the same device ",
93                                 collective_nodes[i]->requested_device());
94       }
95       const auto& deps_j = (*data_dependencies)[collective_nodes[j]];
96       if (deps_i.find(instance_keys[j]) == deps_i.end() &&
97           deps_j.find(instance_keys[i]) == deps_j.end()) {
98         int src_idx = instance_keys[i] > instance_keys[j] ? i : j;
99         int dst_idx = instance_keys[i] > instance_keys[j] ? j : i;
100         Node* src_node = collective_nodes[src_idx];
101         Node* dst_node = collective_nodes[dst_idx];
102         VLOG(1) << "Adding control dependency from node " << src_node->name()
103                 << " instance " << instance_keys[src_idx] << " to node "
104                 << dst_node->name() << " instance " << instance_keys[dst_idx];
105         (*dependency_edges)[src_node].insert(dst_node);
106         auto& src_paths = all_paths[src_node];
107         src_paths.insert(dst_node);
108         for (Node* downstream_node : all_paths[dst_node]) {
109           src_paths.insert(downstream_node);
110         }
111       }
112     }
113   }
114 
115   // Prune dependency edges so that if there are edges a -> b, b -> c, and a ->
116   // c, then remove a -> c.  This dependency would be handled naturally during
117   // op scheduling.
118   for (int i = 0; i < collective_nodes.size(); ++i) {
119     Node* node = collective_nodes[i];
120     auto& neighbor_set = (*dependency_edges)[node];
121     std::vector<Node*> neighbor_list(neighbor_set.begin(), neighbor_set.end());
122     // For all n1, n2 in `neighbor_list` if there is a path from n1 -> n2 then
123     // eliminate n2 from `neighbor_set` and `neighbor_list`.  We remove from
124     // `neighbor_list` by replacing with a `nullptr`, hence the `nullptr` checks
125     // below.
126     for (int j = 0; j < neighbor_list.size(); ++j) {
127       Node* n1 = neighbor_list[j];
128       if (n1 == nullptr) continue;
129       auto& n1_paths = all_paths[n1];
130       for (int k = 0; k < neighbor_list.size(); ++k) {
131         Node* n2 = neighbor_list[k];
132         if (j == k || n2 == nullptr) continue;
133         if (n1_paths.find(n2) != n1_paths.end()) {
134           neighbor_set.erase(n2);
135           neighbor_list[k] = nullptr;
136         }
137       }
138     }
139   }
140 
141   return Status::OK();
142 }
143 
144 // Insert control dependencies defined by `dependency_edges` in `graph`.  If
145 // `order_type` is `kEdges`, insert explicit control edges, else if `order_type`
146 // is `kAttrs`, encode dependencies as an attribute on collective node.
InsertControlDependencies(Graph * graph,GraphCollectiveOrder order_type,const absl::flat_hash_map<Node *,absl::flat_hash_set<Node * >> & dependency_edges)147 Status InsertControlDependencies(
148     Graph* graph, GraphCollectiveOrder order_type,
149     const absl::flat_hash_map<Node*, absl::flat_hash_set<Node*>>&
150         dependency_edges) {
151   if (order_type == GraphCollectiveOrder::kEdges) {
152     for (const auto& pair : dependency_edges) {
153       Node* src_node = pair.first;
154       for (Node* dst_node : pair.second) {
155         graph->AddControlEdge(src_node, dst_node);
156       }
157     }
158   } else if (order_type == GraphCollectiveOrder::kAttrs) {
159     // `wait_for` is the inverse of `dependency_edges`, i.e. `wait_for[node]`
160     // contains the list of instance keys for which `node` must wait.
161     absl::flat_hash_map<Node*, absl::flat_hash_set<int32>> wait_for;
162     for (const auto& pair : dependency_edges) {
163       int32 src_instance;
164       TF_RETURN_IF_ERROR(
165           GetNodeAttr(pair.first->attrs(), "instance_key", &src_instance));
166       for (Node* dst_node : pair.second) {
167         wait_for[dst_node].insert(src_instance);
168       }
169     }
170     for (const auto& pair : wait_for) {
171       std::vector<int32> wait_for_list(pair.second.begin(), pair.second.end());
172       pair.first->ClearAttr("wait_for");
173       pair.first->AddAttr("wait_for", wait_for_list);
174     }
175   } else {
176     return errors::Internal("Unexpected GraphCollectiveOrder type ",
177                             static_cast<int>(order_type));
178   }
179   return Status::OK();
180 }
181 
182 }  // namespace
183 
OrderCollectives(Graph * graph,GraphCollectiveOrder order_type)184 Status OrderCollectives(Graph* graph, GraphCollectiveOrder order_type) {
185   // `instance_keys[i]` corresponds to `collective_nodes[i]`
186   std::vector<Node*> collective_nodes;
187   std::vector<int32> instance_keys;
188   // node -> set of collectives on which node depends.
189   absl::flat_hash_map<Node*, absl::flat_hash_set<int32>> data_dependencies;
190   TF_RETURN_IF_ERROR(DiscoverDataDependencies(
191       graph, &collective_nodes, &instance_keys, &data_dependencies));
192 
193   if (collective_nodes.empty()) return Status::OK();
194 
195   absl::flat_hash_map<Node*, absl::flat_hash_set<Node*>> dependency_edges;
196   // For all pairs of collective nodes n1 and n2 on the same device, if n1 does
197   // not depend on n2 and n2 does not depend on n1, then they are potentially
198   // concurrent.  Create an arbitrary, deterministic ordering between them.
199   TF_RETURN_IF_ERROR(CreateControlDependencies(
200       collective_nodes, instance_keys, &data_dependencies, &dependency_edges));
201 
202   return InsertControlDependencies(graph, order_type, dependency_edges);
203 }
204 
205 }  // namespace tensorflow
206