/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
#include "tensorflow/core/common_runtime/gpu/gpu_stream_util.h"

#include <algorithm>
#include <set>
#include <string>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
26
27 namespace tensorflow {
28 namespace gpu_stream_util {
29
AssignStreams(const Graph * graph,const AssignStreamsOpts & opts,std::unordered_map<int,int> * node_to_stream_id)30 Status AssignStreams(const Graph* graph, const AssignStreamsOpts& opts,
31 std::unordered_map<int, int>* node_to_stream_id) {
32 VLOG(1) << "AssignStreams";
33 Status status;
34
35 // Sanity check arguments.
36 if (graph == nullptr)
37 status.Update(errors::InvalidArgument("Bad graph argument supplied."));
38 if (node_to_stream_id == nullptr) {
39 status.Update(
40 errors::InvalidArgument("Bad node_to_stream_id argument supplied."));
41 }
42 if ((opts.max_streams < 1) || (opts.send_stream >= opts.max_streams) ||
43 (opts.recv_stream >= opts.max_streams) ||
44 (opts.const_stream >= opts.max_streams) ||
45 (opts.compute_stream >= opts.max_streams)) {
46 status.Update(errors::InvalidArgument("Bad graph argument supplied."));
47 }
48 TF_RETURN_IF_ERROR(status);
49
50 // Topologically sort the nodes.
51 std::vector<Node*> order;
52 GetReversePostOrder(*graph, &order);
53 if (VLOG_IS_ON(2)) {
54 for (Node* n : order) {
55 const int node_id = n->id();
56 VLOG(2) << "Node " << node_id << " " << n->type_string() << " "
57 << n->name() << " " << n->in_edges().size() << " inputs";
58 for (const Edge* e : n->in_edges()) {
59 VLOG(2) << " Edge from " << e->src()->id() << " " << e->src()->name()
60 << " fanout " << e->src()->out_edges().size();
61 }
62 }
63 }
64 // We perform stream assignment assuming a large number of
65 // stream IDs and then map these down to the required number of streams
66 // using simple round-robin.
67 // Stream Assignment strategy:
68 // 1. Nodes with zero inputs are always be executed on a
69 // fresh stream.
70 // 2. Try to execute a node on the same stream as one of its
71 // inputs to avoid inter-stream dependencies.
72 // 3. If any input comes from a node with a large fanout then
73 // perhaps an indication that it is shared between parallel
74 // streams of work. We choose a new stream here so that all consumers
75 // of the tensor are likely to run in parallel.
76 int highest_stream_id = -1;
77 for (Node* n : order) {
78 VLOG(3) << "Inspecting node " << n->DebugString();
79 const int node_id = n->id();
80 const string& op = n->type_string();
81
82 // Determine a suitable stream to use.
83 int stream_id = highest_stream_id + 1;
84 for (const Edge* e : n->in_edges()) {
85 const size_t fanout = e->src()->out_edges().size();
86 if (fanout == 1) {
87 stream_id = (*node_to_stream_id)[e->src()->id()];
88 break;
89 }
90 }
91 // Override stream for specific op types.
92 if (op == "_Send") {
93 if (opts.send_stream >= 0) stream_id = opts.send_stream;
94 } else if (op == "_Recv") {
95 if (opts.recv_stream >= 0) stream_id = opts.recv_stream;
96 } else if (op == "Const") {
97 if (opts.const_stream >= 0) stream_id = opts.const_stream;
98 } else {
99 if (opts.compute_stream >= 0) stream_id = opts.compute_stream;
100 }
101
102 (*node_to_stream_id)[node_id] = stream_id % opts.max_streams;
103 highest_stream_id = std::max(stream_id, highest_stream_id);
104 }
105 VLOG(1) << "Identified " << highest_stream_id << " candidate streams for "
106 << order.size() << " nodes.";
107
108 return Status::OK();
109 }
110
111 } // namespace gpu_stream_util
112 } // namespace tensorflow
113