/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/memory_types.h"

#include <atomic>
#include <functional>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

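// An endpoint identifies a single output slot of a node: {node id, output
// index}.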
struct Endpoint {
  int node_id;
  int output_index;
};

struct EndpointHash {
  uint32 operator()(const Endpoint& x) const {
    return Hash32(reinterpret_cast<const char*>(&x.node_id), sizeof(int),
                  x.output_index);
  }
};

struct EndpointEq {
  bool operator()(const Endpoint& x, const Endpoint& y) const {
    return (x.node_id == y.node_id) && (x.output_index == y.output_index);
  }
};

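// Invokes 'fn' once for every data edge in 'g', passing the memory type of
// the edge's source output and of its destination input. Endpoints without a
// recorded memory type default to DEVICE_MEMORY.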
static Status ProcessMemoryTypes(
    const DeviceType& device_type, const Graph* g,
    const std::function<Status(const Edge*, MemoryType, MemoryType)>& fn) {
  if (device_type != DEVICE_GPU && device_type != DEVICE_SYCL) {
    // On non-GPU and non-SYCL devices, HOST_MEMORY and DEVICE_MEMORY are
    // always compatible.
    return Status::OK();
  }
  // On GPU and SYCL devices, HOST_MEMORY and DEVICE_MEMORY are not
  // compatible; i.e., a conversion/transfer must be done.
  //
  // {node id, slot id} -> memory type.
  typedef std::unordered_map<Endpoint, MemoryType, EndpointHash, EndpointEq>
      MemTypeMap;
  MemTypeMap inp;
  MemTypeMap out;
  MemoryTypeVector inp_mvec;
  MemoryTypeVector out_mvec;
  for (const Node* n : g->nodes()) {
    TF_RETURN_IF_ERROR(MemoryTypesForNode(g->op_registry(), device_type,
                                          n->def(), &inp_mvec, &out_mvec));
    for (size_t i = 0; i < inp_mvec.size(); ++i) {
      VLOG(2) << "inp mvec " << n->id() << " " << i << " " << inp_mvec[i];
      inp[{n->id(), static_cast<int>(i)}] = inp_mvec[i];
    }
    for (size_t i = 0; i < out_mvec.size(); ++i) {
      VLOG(2) << "out mvec " << n->id() << " " << i << " " << out_mvec[i];
      out[{n->id(), static_cast<int>(i)}] = out_mvec[i];
    }
  }
  for (const Edge* e : g->edges()) {
    if (e->IsControlEdge()) {
      continue;
    }
    MemoryType sm = gtl::FindWithDefault(out, {e->src()->id(), e->src_output()},
                                         DEVICE_MEMORY);
    MemoryType dm = gtl::FindWithDefault(inp, {e->dst()->id(), e->dst_input()},
                                         DEVICE_MEMORY);
    VLOG(1) << e->src()->id() << ":" << e->src_output() << " -> "
            << e->dst()->id() << ":" << e->dst_input() << ": " << sm << " -> "
            << dm;
    TF_RETURN_IF_ERROR(fn(e, sm, dm));
  }
  return Status::OK();
}

Status ValidateMemoryTypes(const DeviceType& device_type, const Graph* g) {
  return ProcessMemoryTypes(
      device_type, g, [](const Edge* e, MemoryType sm, MemoryType dm) {
        if (sm == dm) {
          return Status::OK();
        }
        return errors::Internal("Memory type mismatch (", sm, " ", dm,
                                ") between :", e->src()->id(), ":",
                                e->src_output(), " and ", e->dst()->id(), ":",
                                e->dst_input(), " : from ",
                                FormatNodeForError(*e->src()), " to ",
                                FormatNodeForError(*e->dst()));
      });
}

// Given an Edge whose two endpoints have different memory types and for
// which we are going to insert a pair of HostSend/Recv or Send/HostRecv
// nodes, GetTensorName() returns a unique string that we can use as part
// of the rendezvous key. The returned string is guaranteed to be unique
// within this process. That is sufficient because EnsureMemoryTypes is
// only used on a TensorFlow graph that is going to be executed on a
// single tf device (hence within a single process).
static string GetTensorName(const Edge* edge) {
  static std::atomic<int64> counter(0);
  return strings::StrCat("memtype_", counter.fetch_add(1), "_",
                         edge->src()->name());
}

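// Adds a _HostSend (if 'host' is true) or _Send node that reads the source
// endpoint of 'edge' and publishes the tensor under 'tensor_name' on
// 'device_name'. Returns the new node.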
static Node* Send(Graph* g, const string& tensor_name,
                  const string& device_name, bool host, const Edge* edge) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), host ? "_HostSend" : "_Send")
                  .Input(edge->src(), edge->src_output())
                  .Attr("tensor_name", tensor_name)
                  .Attr("send_device", device_name)
                  .Attr("send_device_incarnation", 0)  // Do not care.
                  .Attr("recv_device", device_name)
                  .Attr("_hostmem_sendrecv", true)
                  .Finalize(g, &ret));
  return ret;
}

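// Adds the matching _HostRecv (if 'host' is true) or _Recv node that
// consumes the tensor published under 'tensor_name' on 'device_name'.
// Returns the new node.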
static Node* Recv(Graph* g, const string& tensor_name,
                  const string& device_name, bool host, const Edge* edge) {
  Node* ret;
  TF_CHECK_OK(
      NodeBuilder(g->NewName("n"), host ? "_HostRecv" : "_Recv")
          .Attr("tensor_type", edge->src()->output_type(edge->src_output()))
          .Attr("tensor_name", tensor_name)
          .Attr("send_device", device_name)
          .Attr("send_device_incarnation", 0)
          .Attr("recv_device", device_name)
          .Attr("_hostmem_sendrecv", true)
          .Finalize(g, &ret));
  return ret;
}

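// Rewrites 'g' so that every edge connects endpoints with compatible memory
// types: each edge whose source and destination memory types differ gets a
// HostSend/Recv or Send/HostRecv pair inserted. The rewritten graph is then
// re-validated.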
Status EnsureMemoryTypes(const DeviceType& device_type,
                         const string& device_name, Graph* g) {
  struct Item {
    const Edge* edge;
    MemoryType sm;
    MemoryType dm;
  };
  std::vector<Item> edges;
  TF_RETURN_IF_ERROR(ProcessMemoryTypes(
      device_type, g, [&edges](const Edge* e, MemoryType sm, MemoryType dm) {
        if (sm == dm) {
          return Status::OK();
        }
        if (((sm == HOST_MEMORY) && (dm == DEVICE_MEMORY)) ||
            ((sm == DEVICE_MEMORY) && (dm == HOST_MEMORY))) {
          edges.push_back({e, sm, dm});
          return Status::OK();
        }
        return errors::Internal("Unexpected memory type pair on an edge: ", sm,
                                " vs. ", dm);
      }));

  // 'edges' contains the edges in 'g' whose endpoint memory types are not
  // compatible. Therefore, if we found any, we need to insert
  // HostSend/Recv and Send/HostRecv pairs. 'recv_nodes' records all Recv
  // nodes we have added so that we don't copy the same tensor more than
  // once.
  if (!edges.empty()) {
    std::unordered_map<Endpoint, Node*, EndpointHash, EndpointEq> recv_nodes;
    for (const auto& item : edges) {
      const Edge* e = item.edge;
      const bool has_ref = IsRefType(e->src()->output_type(e->src_output()));
      Node* recv = nullptr;
      Endpoint key{e->src()->id(), e->src_output()};
      auto iter = recv_nodes.find(key);
      if (iter == recv_nodes.end()) {
        const string tensor_name = GetTensorName(e);
        Node* send =
            Send(g, tensor_name, device_name, (item.sm == HOST_MEMORY), e);
        recv = Recv(g, tensor_name, device_name, (item.dm == HOST_MEMORY), e);
        if (!has_ref) {
          // We only cache the Recv node if no ref type is involved.
          recv_nodes[key] = recv;
        }
        g->AddControlEdge(send, recv);
      } else {
        recv = iter->second;
      }
      g->AddEdge(recv, 0, e->dst(), e->dst_input());
      g->RemoveEdge(e);
    }
  }
  return ValidateMemoryTypes(device_type, g);
}

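// Stores in '*memory_type' the memory type of the 'index'-th output of node
// 'n' on 'device_type'. Returns an error if 'index' is out of range.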
Status MemoryTypeForOutput(const DeviceType& device_type, const Graph* g,
                           const Node* n, int index, MemoryType* memory_type) {
  MemoryTypeVector inp_mvec;
  MemoryTypeVector out_mvec;
  TF_RETURN_IF_ERROR(MemoryTypesForNode(g->op_registry(), device_type, n->def(),
                                        &inp_mvec, &out_mvec));
  if (index < 0 || out_mvec.size() <= static_cast<size_t>(index)) {
    return errors::Internal("Trying to get the memory type for ", index,
                            "'th output of node ", FormatNodeForError(*n),
                            " that has only ", out_mvec.size(), " outputs");
  }
  *memory_type = out_mvec[index];
  return Status::OK();
}


}  // end namespace tensorflow