/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/core/common_runtime/memory_types.h"
16
17 #include <utility>
18
19 #include "tensorflow/core/framework/memory_types.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/graph/node_builder.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/gtl/map_util.h"
24 #include "tensorflow/core/lib/hash/hash.h"
25 #include "tensorflow/core/platform/types.h"
26 #include "tensorflow/core/util/dump_graph.h"
27
28 namespace tensorflow {
29
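// A (node id, slot index) pair identifying one input or output of a node.
// Used below as the key for per-tensor memory-type lookups.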
struct Endpoint {
  int node_id;
  int output_index;
};

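// Hash functor so that Endpoint can be used as an unordered_map key.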
struct EndpointHash {
  uint32 operator()(const Endpoint& x) const {
    return Hash32(reinterpret_cast<const char*>(&x.node_id), sizeof(int),
                  x.output_index);
  }
};

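// Equality functor for Endpoint keys.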
struct EndpointEq {
  uint32 operator()(const Endpoint& x, const Endpoint& y) const {
    return (x.node_id == y.node_id) && (x.output_index == y.output_index);
  }
};

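// Calls 'fn' on every non-control edge of 'g' with the memory types of the
// edge's source output and destination input. On non-GPU devices this is a
// no-op, since host and device memory are always compatible there.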
static Status ProcessMemoryTypes(
    const DeviceType& device_type, const Graph* g,
    const std::function<Status(const Edge*, MemoryType, MemoryType)>& fn) {
  if (device_type != DEVICE_GPU) {
    // On non-GPU devices, HOST_MEMORY and DEVICE_MEMORY are always compatible.
    return Status::OK();
  }
  // For GPU, HOST_MEMORY and DEVICE_MEMORY are not compatible. I.e., a
  // conversion/transfer must be done.
  //
  // {node id, slot id} -> memory type.
  typedef std::unordered_map<Endpoint, MemoryType, EndpointHash, EndpointEq>
      MemTypeMap;
  MemTypeMap inp;
  MemTypeMap out;
  MemoryTypeVector inp_mvec;
  MemoryTypeVector out_mvec;
  for (const Node* n : g->nodes()) {
    TF_RETURN_IF_ERROR(MemoryTypesForNode(g->op_registry(), device_type,
                                          n->def(), &inp_mvec, &out_mvec));
    for (size_t i = 0; i < inp_mvec.size(); ++i) {
      VLOG(2) << "inp mvec " << n->id() << " " << i << " " << inp_mvec[i];
      inp[{n->id(), static_cast<int>(i)}] = inp_mvec[i];
    }
    for (size_t i = 0; i < out_mvec.size(); ++i) {
      VLOG(2) << "out mvec " << n->id() << " " << i << " " << out_mvec[i];
      out[{n->id(), static_cast<int>(i)}] = out_mvec[i];
    }
  }
  for (const Edge* e : g->edges()) {
    if (e->IsControlEdge()) {
      continue;
    }
    MemoryType sm = gtl::FindWithDefault(out, {e->src()->id(), e->src_output()},
                                         DEVICE_MEMORY);
    MemoryType dm = gtl::FindWithDefault(inp, {e->dst()->id(), e->dst_input()},
                                         DEVICE_MEMORY);
    VLOG(1) << e->src()->id() << ":" << e->src_output() << " -> "
            << e->dst()->id() << ":" << e->dst_input() << ": " << sm << " -> "
            << dm;
    TF_RETURN_IF_ERROR(fn(e, sm, dm));
  }
  return Status::OK();
}

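// Returns an error for the first edge in 'g' whose source and destination
// disagree on memory type; otherwise returns OK.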
Status ValidateMemoryTypes(const DeviceType& device_type, const Graph* g) {
  return ProcessMemoryTypes(
      device_type, g, [](const Edge* e, MemoryType sm, MemoryType dm) {
        if (sm == dm) {
          return Status::OK();
        }
        return errors::Internal("Memory type mismatch (", sm, " ", dm,
                                ") between :", e->src()->id(), ":",
                                e->src_output(), " and ", e->dst()->id(), ":",
                                e->dst_input(), " : from ",
                                FormatNodeForError(*e->src()), " to ",
                                FormatNodeForError(*e->dst()));
      });
}

// Given an Edge whose two endpoints have different memory types and for
// which we are going to insert a HostSend/Recv or Send/HostRecv pair,
// GetTensorName() returns a unique string that we can use as part of the
// rendezvous key. The returned string is guaranteed to be unique within
// this process. That is sufficient because EnsureMemoryTypes is only used
// on a TensorFlow graph that is going to be executed on a single tf device
// (hence within a single process).
static string GetTensorName(const Edge* edge) {
  static std::atomic<int64> counter(0);
  return strings::StrCat("memtype_", counter.fetch_add(1), "_",
                         edge->src()->name());
}

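// Adds a _HostSend (host == true) or _Send node to 'g' that sends the tensor
// produced by the source endpoint of 'edge' under the rendezvous name
// 'tensor_name'.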
static Node* Send(Graph* g, const string& tensor_name,
                  const string& device_name, bool host, const Edge* edge) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), host ? "_HostSend" : "_Send")
                  .Input(edge->src(), edge->src_output())
                  .Attr("tensor_name", tensor_name)
                  .Attr("send_device", device_name)
                  .Attr("send_device_incarnation", 0)  // Do not care.
                  .Attr("recv_device", device_name)
                  .Attr("_hostmem_sendrecv", true)
                  .Attr("_src", edge->src()->name())
                  .Attr("_dst", edge->dst()->name())
                  .Finalize(g, &ret));
  return ret;
}

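// Adds the matching _HostRecv (host == true) or _Recv node to 'g' for the
// same rendezvous name.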
static Node* Recv(Graph* g, const string& tensor_name,
                  const string& device_name, bool host, const Edge* edge) {
  Node* ret;
  TF_CHECK_OK(
      NodeBuilder(g->NewName("n"), host ? "_HostRecv" : "_Recv")
          .Attr("tensor_type", edge->src()->output_type(edge->src_output()))
          .Attr("tensor_name", tensor_name)
          .Attr("send_device", device_name)
          .Attr("send_device_incarnation", 0)
          .Attr("recv_device", device_name)
          .Attr("_hostmem_sendrecv", true)
          .Attr("_src", edge->src()->name())
          .Attr("_dst", edge->dst()->name())
          .Finalize(g, &ret));
  return ret;
}

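// Rewrites 'g' so that every edge connects endpoints with compatible memory
// types, inserting a HostSend/Recv or Send/HostRecv pair wherever a
// host/device transfer is needed, then re-validates the graph.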
Status EnsureMemoryTypes(const DeviceType& device_type,
                         const string& device_name, Graph* g) {
  struct Item {
    const Edge* edge;
    MemoryType sm;
    MemoryType dm;
  };
  std::vector<Item> edges;
  TF_RETURN_IF_ERROR(ProcessMemoryTypes(
      device_type, g, [&edges](const Edge* e, MemoryType sm, MemoryType dm) {
        if (sm == dm) {
          return Status::OK();
        }
        if (((sm == HOST_MEMORY) && (dm == DEVICE_MEMORY)) ||
            ((sm == DEVICE_MEMORY) && (dm == HOST_MEMORY))) {
          edges.push_back({e, sm, dm});
          return Status::OK();
        }
        return errors::Internal("Unexpected memory type pair on an edge: ", sm,
                                " vs. ", dm);
      }));

  // 'edges' contains the edges in 'g' whose endpoints have incompatible
  // memory types. Therefore, if we found any, we need to insert
  // HostSend/Recv and Send/HostRecv pairs. 'recv_nodes' records all the
  // Recv nodes we added so that we don't copy the same tensor more than
  // once.
  if (!edges.empty()) {
    std::unordered_map<Endpoint, Node*, EndpointHash, EndpointEq> recv_nodes;
    for (const auto& item : edges) {
      const Edge* e = item.edge;
      const bool has_ref = IsRefType(e->src()->output_type(e->src_output()));
      Node* recv = nullptr;
      Endpoint key{e->src()->id(), e->src_output()};
      auto iter = recv_nodes.find(key);
      if (iter == recv_nodes.end()) {
        const string tensor_name = GetTensorName(e);
        Node* send =
            Send(g, tensor_name, device_name, (item.sm == HOST_MEMORY), e);
        recv = Recv(g, tensor_name, device_name, (item.dm == HOST_MEMORY), e);
        if (!has_ref) {
          // We only cache the Recv node if no ref type is involved.
          recv_nodes[key] = recv;
        }
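        // The control edge keeps the new pair ordered: the Recv cannot run
        // before its matching Send.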
        g->AddControlEdge(send, recv);
      } else {
        recv = iter->second;
      }
      g->AddEdge(recv, 0, e->dst(), e->dst_input());
      g->RemoveEdge(e);
    }
  }

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "Dumped graph after EnsureMemoryTypes to "
            << DumpGraphToFile("EnsureMemoryTypes", *g);
  }

  return ValidateMemoryTypes(device_type, g);
}

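// Sets '*memory_type' to the memory type of the 'index'-th output of node
// 'n' on 'device_type'; returns an error if 'n' has fewer than 'index' + 1
// outputs.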
Status MemoryTypeForOutput(const DeviceType& device_type, const Graph* g,
                           const Node* n, int index, MemoryType* memory_type) {
  MemoryTypeVector inp_mvec;
  MemoryTypeVector out_mvec;
  TF_RETURN_IF_ERROR(MemoryTypesForNode(g->op_registry(), device_type, n->def(),
                                        &inp_mvec, &out_mvec));
  if (out_mvec.size() <= index) {
    return errors::Internal("Trying to get the memory type for ", index,
                            "'th output of node ", FormatNodeForError(*n),
                            " that has only ", out_mvec.size(), " outputs");
  }
  *memory_type = out_mvec[index];
  return Status::OK();
}

}  // end namespace tensorflow