/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/memory_types.h"

#include <utility>

#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

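// A (node id, slot) pair identifying one input or output of a node.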
struct Endpoint {
  int node_id;
  int output_index;
};

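// Hashes an Endpoint by hashing the node id bytes, seeded with the slot index.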
struct EndpointHash {
  uint32 operator()(const Endpoint& x) const {
    return Hash32(reinterpret_cast<const char*>(&x.node_id), sizeof(int),
                  x.output_index);
  }
};

struct EndpointEq {
  bool operator()(const Endpoint& x, const Endpoint& y) const {
    return (x.node_id == y.node_id) && (x.output_index == y.output_index);
  }
};

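// Invokes `fn` once per non-control edge in `g`, passing the memory types of
// the edge's source output and destination input. On non-GPU device types it
// is a no-op that returns success immediately.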
static Status ProcessMemoryTypes(
    const DeviceType& device_type, const Graph* g,
    const std::function<Status(const Edge*, MemoryType, MemoryType)>& fn) {
  if (device_type != DEVICE_GPU) {
    // On non-GPU devices, HOST_MEMORY and DEVICE_MEMORY are always compatible.
    return Status::OK();
  }
  // For GPU, HOST_MEMORY and DEVICE_MEMORY are not compatible. I.e., a
  // conversion/transfer must be done.
  //
  // {node id, slot id} -> memory type.
  typedef std::unordered_map<Endpoint, MemoryType, EndpointHash, EndpointEq>
      MemTypeMap;
  MemTypeMap inp;
  MemTypeMap out;
  MemoryTypeVector inp_mvec;
  MemoryTypeVector out_mvec;
  for (const Node* n : g->nodes()) {
    TF_RETURN_IF_ERROR(MemoryTypesForNode(g->op_registry(), device_type,
                                          n->def(), &inp_mvec, &out_mvec));
    for (size_t i = 0; i < inp_mvec.size(); ++i) {
      VLOG(2) << "inp mvec " << n->id() << " " << i << " " << inp_mvec[i];
      inp[{n->id(), static_cast<int>(i)}] = inp_mvec[i];
    }
    for (size_t i = 0; i < out_mvec.size(); ++i) {
      VLOG(2) << "out mvec " << n->id() << " " << i << " " << out_mvec[i];
      out[{n->id(), static_cast<int>(i)}] = out_mvec[i];
    }
  }
  for (const Edge* e : g->edges()) {
    if (e->IsControlEdge()) {
      continue;
    }
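    // Endpoints missing from the maps default to DEVICE_MEMORY.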
    MemoryType sm = gtl::FindWithDefault(out, {e->src()->id(), e->src_output()},
                                         DEVICE_MEMORY);
    MemoryType dm = gtl::FindWithDefault(inp, {e->dst()->id(), e->dst_input()},
                                         DEVICE_MEMORY);
    VLOG(1) << e->src()->id() << ":" << e->src_output() << " -> "
            << e->dst()->id() << ":" << e->dst_input() << ": " << sm << " -> "
            << dm;
    TF_RETURN_IF_ERROR(fn(e, sm, dm));
  }
  return Status::OK();
}

Status ValidateMemoryTypes(const DeviceType& device_type, const Graph* g) {
  return ProcessMemoryTypes(
      device_type, g, [](const Edge* e, MemoryType sm, MemoryType dm) {
        if (sm == dm) {
          return Status::OK();
        }
        return errors::Internal("Memory type mismatch (", sm, " ", dm,
                                ") between :", e->src()->id(), ":",
                                e->src_output(), " and ", e->dst()->id(), ":",
                                e->dst_input(), " : from ",
                                FormatNodeForError(*e->src()), " to ",
                                FormatNodeForError(*e->dst()));
      });
}

// Given an Edge whose two endpoints have different memory types and is
// therefore going to get a pair of HostSend/Recv or Send/HostRecv nodes
// inserted, GetTensorName() returns a unique string that we can use as part
// of the rendezvous key. The returned string is guaranteed to be unique
// within this process. That is sufficient because EnsureMemoryTypes is only
// used on a TensorFlow graph that is going to be executed on a single TF
// device (hence within a single process).
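// E.g., the first name this function ever generates for an edge whose source
// node is named "MatMul" would be "memtype_0_MatMul".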
static string GetTensorName(const Edge* edge) {
  static std::atomic<int64> counter(0);
  return strings::StrCat("memtype_", counter.fetch_add(1), "_",
                         edge->src()->name());
}

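// Adds a _HostSend (host == true) or _Send node consuming the source endpoint
// of `edge`. Note send_device and recv_device are both `device_name`: the
// pair implements an intra-device transfer, not a cross-device one.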
static Node* Send(Graph* g, const string& tensor_name,
                  const string& device_name, bool host, const Edge* edge) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), host ? "_HostSend" : "_Send")
                  .Input(edge->src(), edge->src_output())
                  .Attr("tensor_name", tensor_name)
                  .Attr("send_device", device_name)
                  .Attr("send_device_incarnation", 0)  // Do not care.
                  .Attr("recv_device", device_name)
                  .Attr("_hostmem_sendrecv", true)
                  .Attr("_src", edge->src()->name())
                  .Attr("_dst", edge->dst()->name())
                  .Finalize(g, &ret));
  return ret;
}

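// Adds the matching _HostRecv (host == true) or _Recv node that produces the
// tensor registered under the same rendezvous `tensor_name`.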
static Node* Recv(Graph* g, const string& tensor_name,
                  const string& device_name, bool host, const Edge* edge) {
  Node* ret;
  TF_CHECK_OK(
      NodeBuilder(g->NewName("n"), host ? "_HostRecv" : "_Recv")
          .Attr("tensor_type", edge->src()->output_type(edge->src_output()))
          .Attr("tensor_name", tensor_name)
          .Attr("send_device", device_name)
          .Attr("send_device_incarnation", 0)
          .Attr("recv_device", device_name)
          .Attr("_hostmem_sendrecv", true)
          .Attr("_src", edge->src()->name())
          .Attr("_dst", edge->dst()->name())
          .Finalize(g, &ret));
  return ret;
}

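// A minimal usage sketch (the device name below is hypothetical): after a
// graph has been placed on a single GPU, a caller might run
//
//   TF_RETURN_IF_ERROR(EnsureMemoryTypes(
//       DeviceType(DEVICE_GPU), "/device:GPU:0", graph));
//
// to rewrite every HOST_MEMORY <-> DEVICE_MEMORY edge before execution.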
Status EnsureMemoryTypes(const DeviceType& device_type,
                         const string& device_name, Graph* g) {
  struct Item {
    const Edge* edge;
    MemoryType sm;
    MemoryType dm;
  };
  std::vector<Item> edges;
  TF_RETURN_IF_ERROR(ProcessMemoryTypes(
      device_type, g, [&edges](const Edge* e, MemoryType sm, MemoryType dm) {
        if (sm == dm) {
          return Status::OK();
        }
        if (((sm == HOST_MEMORY) && (dm == DEVICE_MEMORY)) ||
            ((sm == DEVICE_MEMORY) && (dm == HOST_MEMORY))) {
          edges.push_back({e, sm, dm});
          return Status::OK();
        }
        return errors::Internal("Unexpected memory type pair on an edge: ", sm,
                                " vs. ", dm);
      }));

  // `edges` contains the edges in 'g' whose endpoint memory types are
  // incompatible. If we found any, we need to insert HostSend/Recv or
  // Send/HostRecv pairs. `recv_nodes` records the Recv nodes we have added
  // so that we don't copy the same tensor more than once.
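  //
  // Schematically, for one incompatible edge (the sm == HOST_MEMORY case):
  //
  //   before:  src:out ----------------------------------> dst:in
  //   after:   src:out --> _HostSend --(ctrl)--> _Recv:0 --> dst:in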
  if (!edges.empty()) {
    std::unordered_map<Endpoint, Node*, EndpointHash, EndpointEq> recv_nodes;
    for (const auto& item : edges) {
      const Edge* e = item.edge;
      const bool has_ref = IsRefType(e->src()->output_type(e->src_output()));
      Node* recv = nullptr;
      Endpoint key{e->src()->id(), e->src_output()};
      auto iter = recv_nodes.find(key);
      if (iter == recv_nodes.end()) {
        const string tensor_name = GetTensorName(e);
        Node* send =
            Send(g, tensor_name, device_name, (item.sm == HOST_MEMORY), e);
        recv = Recv(g, tensor_name, device_name, (item.dm == HOST_MEMORY), e);
        if (!has_ref) {
          // We only cache the Recv node if no ref type is involved.
          recv_nodes[key] = recv;
        }
        g->AddControlEdge(send, recv);
      } else {
        recv = iter->second;
      }
      g->AddEdge(recv, 0, e->dst(), e->dst_input());
      g->RemoveEdge(e);
    }
  }

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "Dumped graph after EnsureMemoryTypes to "
            << DumpGraphToFile("EnsureMemoryTypes", *g);
  }

  return ValidateMemoryTypes(device_type, g);
}

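// Returns, in `*memory_type`, the memory type of output `index` of node `n`,
// or an internal error if `n` has fewer than `index + 1` outputs.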
Status MemoryTypeForOutput(const DeviceType& device_type, const Graph* g,
                           const Node* n, int index, MemoryType* memory_type) {
  MemoryTypeVector inp_mvec;
  MemoryTypeVector out_mvec;
  TF_RETURN_IF_ERROR(MemoryTypesForNode(g->op_registry(), device_type, n->def(),
                                        &inp_mvec, &out_mvec));
  if (out_mvec.size() <= index) {
    return errors::Internal("Trying to get the memory type for ", index,
                            "'th output of node ", FormatNodeForError(*n),
                            " that has only ", out_mvec.size(), " outputs");
  }
  *memory_type = out_mvec[index];
  return Status::OK();
}

}  // end namespace tensorflow