• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
17 
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/tensor_shape.pb.h"
20 #include "tensorflow/core/framework/types.h"
21 #include "tensorflow/core/grappler/graph_view.h"
22 #include "tensorflow/core/grappler/grappler_item.h"
23 #include "tensorflow/core/grappler/op_types.h"
24 #include "tensorflow/core/grappler/utils/symbolic_shapes.h"
25 #include "tensorflow/core/grappler/utils/topological_sort.h"
26 #include "tensorflow/core/lib/core/error_codes.pb.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/strings/str_util.h"
29 
30 namespace tensorflow {
31 namespace grappler {
32 namespace internal {
33 
// Maximum number of elements a tensor may have to still be considered "small"
// (and therefore cheap enough to pin to the Host device).
// TODO(williamchan): Change this constant to be something smarter, maybe
// dynamically determined.
constexpr int64 kTensorMaxSize = 64;
37 
38 // All the nodes that should be blacklisted and not swapped.
IsBlacklisted(const NodeDef & node)39 bool IsBlacklisted(const NodeDef& node) {
40   return
41       // Collective ops should not be swapped.
42       IsCollective(node) ||
43       // ControlFlow ops should not be swapped.
44       IsControlFlow(node) ||
45       // NoOp ops should not be swapped (due to group dependencies).
46       IsNoOp(node);
47 }
48 
49 // Check if Tensor is either a string or is integer and small size
IsTensorSmall(const OpInfo::TensorProperties & prop)50 bool IsTensorSmall(const OpInfo::TensorProperties& prop) {
51   if (prop.dtype() == DataType::DT_STRING) {
52     return true;
53   }
54 
55   // Check type to be int32 or int64.
56   if (prop.dtype() != DataType::DT_INT32 &&
57       prop.dtype() != DataType::DT_INT64) {
58     return false;
59   }
60 
61   // Check size known and small.
62   const int64 size = NumCoefficients(prop.shape());
63   if (size < 0 || size > kTensorMaxSize) {
64     return false;
65   }
66 
67   return true;
68 }
69 
70 // Find KernelDef for `node`, greedily return first found from `devices`.
TryFindKernelDef(const std::vector<DeviceType> & devices,const NodeDef & node,const KernelDef ** kdef)71 Status TryFindKernelDef(const std::vector<DeviceType>& devices,
72                         const NodeDef& node, const KernelDef** kdef) {
73   for (const DeviceType& device : devices) {
74     const KernelDef* kernel = nullptr;
75     Status s = FindKernelDef(device, node, &kernel, nullptr);
76     if (s.ok()) {
77       if (kdef) {
78         *kdef = kernel;
79       }
80       return Status::OK();
81     }
82   }
83 
84   return errors::NotFound("Could not find KernelDef for op: ", node.op());
85 }
86 
// Checks if a node's output port is host friendly.
// Roughly this means checking if the output port is on Host memory.
// Sets `*is_candidate`; lookup failures (missing OpDef/KernelDef, bad port)
// are logged and reported as "not friendly" rather than returned as errors,
// so the optimization pass can continue over the rest of the graph.
Status IsNodeOutputPortHostFriendly(const GraphView& graph,
                                    GraphProperties* properties,
                                    const NodeDef& node, int port_id,
                                    bool* is_candidate) {
  *is_candidate = false;

  // Make sure we are not a blacklisted op.
  if (IsBlacklisted(node)) {
    return Status::OK();
  }

  // Check to make sure we have the right properties (i.e., statically shaped).
  if (!properties->has_properties()) {
    // This is an expensive call, call it lazily.
    TF_RETURN_IF_ERROR(properties->InferStatically(
        /*assume_valid_feeds=*/false));
  }
  const auto& output_properties = properties->GetOutputProperties(node.name());
  if (port_id >= output_properties.size()) {
    // Shape inference produced fewer outputs than the requested port; treat
    // as not a candidate rather than failing the pass.
    LOG(WARNING) << "port_id=" << port_id
                 << " but output_properties.size()=" << output_properties.size()
                 << "\n"
                 << node.DebugString();
    return Status::OK();
  }
  // Only small integer/string tensors are worth pinning to Host.
  if (!IsTensorSmall(output_properties[port_id])) {
    return Status::OK();
  }

  // These nodes may be optimized away downstream (even if pinned to Host), we
  // should (recursively) check their source.
  if (IsIdentity(node) || IsIdentityNSingleInput(node)) {
    // Every fanin must itself be Host friendly for the pass-through node to
    // count as a candidate.
    for (const auto& fanin : graph.GetFanins(node, false)) {
      bool fanin_candidate = false;
      TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
          graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
      if (!fanin_candidate) {
        return Status::OK();
      }
    }
    *is_candidate = true;
    return Status::OK();
  }

  // Check if op's device is on CPU.
  if (str_util::StrContains(node.device(), DEVICE_CPU)) {
    *is_candidate = true;
    return Status::OK();
  }

  // Check if op's output port is pinned to HostMemory.
  const OpDef* op = nullptr;
  Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
  if (!s.ok()) {
    LOG(WARNING) << "Could not find OpDef for : " << node.op();
    return Status::OK();
  }

  // Map the port_id to output_arg_id (an output arg may span several ports).
  const int output_arg_id = OpOutputPortIdToArgId(node, *op, port_id);
  if (output_arg_id < 0) {
    LOG(WARNING) << "Invalid port: " << port_id << "!\n"
                 << node.DebugString() << "\n"
                 << op->DebugString();
    return Status::OK();
  }

  // Find the kernel, preferring the node's own device, then GPU, then CPU.
  const KernelDef* kernel = nullptr;
  s = TryFindKernelDef({node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node,
                       &kernel);
  if (!s.ok()) {
    LOG(INFO) << "Could not find KernelDef for: " << node.op();
    return Status::OK();
  }

  // Check if the output_arg is pinned to Host.
  for (const string& host_memory_arg : kernel->host_memory_arg()) {
    if (op->output_arg(output_arg_id).name() == host_memory_arg) {
      *is_candidate = true;
      break;
    }
  }

  return Status::OK();
}
175 
176 // Checks if a node's input port is Host friendly.
177 // Roughly this means checking if the input port is on Host memory.
IsNodeInputPortHostFriendly(const NodeDef & node,int port_id)178 bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
179   // If node is on Host, assume its inputs are Host friendly.
180   if (str_util::StrContains(node.device(), DEVICE_CPU)) {
181     return true;
182   }
183 
184   // Check if op's input port is pinned to HostMemory.
185   const OpDef* op = nullptr;
186   Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
187   if (!s.ok()) {
188     LOG(WARNING) << "Could not find OpDef for : " << node.op();
189     return false;
190   }
191   const int input_arg_id = OpInputPortIdToArgId(node, *op, port_id);
192 
193   // Find the kernel.
194   const KernelDef* kernel = nullptr;
195   s = internal::TryFindKernelDef(
196       {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, &kernel);
197   if (!s.ok()) {
198     LOG(INFO) << "Could not find KernelDef for: " << node.op();
199     return false;
200   }
201 
202   // Check if the input_arg is pinned to Host.
203   for (const string& host_memory_arg : kernel->host_memory_arg()) {
204     if (op->input_arg(input_arg_id).name() == host_memory_arg) {
205       return true;
206     }
207   }
208 
209   return false;
210 }
211 
// Checks if a node is a candidate to pin to Host.
// The rough algorithm is as follows:
// 1] Check if node is blacklisted.
// 2] Check if node can run on Host.
// 3] Check all input/outputs are Host "friendly" (atm, friendly means small,
//    ints, and pinned to Host).
// Sets `*is_candidate`; only shape-inference failures propagate as errors.
Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
                           const NodeDef& node, bool* is_candidate) {
  *is_candidate = false;

  // Check if node already on CPU; nothing further to verify in that case.
  if (str_util::StrContains(node.device(), DEVICE_CPU)) {
    *is_candidate = true;
    return Status::OK();
  }

  // Skip these node types.
  if (IsBlacklisted(node)) {
    return Status::OK();
  }

  // Check the node can be run on CPU; a missing CPU kernel simply means the
  // node is not a candidate, not an error.
  Status s = TryFindKernelDef({DEVICE_CPU}, node, nullptr);
  if (!s.ok()) {
    return Status::OK();
  }

  // Check all inputs are Host friendly (i.e., every producing output port of
  // every data fanin would live in Host memory).
  for (const GraphView::OutputPort& fanin :
       graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
    bool fanin_candidate = false;
    TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
        graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
    if (!fanin_candidate) {
      return Status::OK();
    }
  }

  // Check all outputs are Host friendly (small string/int tensors).
  if (!properties->has_properties()) {
    // This is an expensive call, call it lazily.
    TF_RETURN_IF_ERROR(properties->InferStatically(
        /*assume_valid_feeds=*/false));
  }
  for (const auto& prop : properties->GetOutputProperties(node.name())) {
    if (!IsTensorSmall(prop)) {
      return Status::OK();
    }
  }

  *is_candidate = true;
  return Status::OK();
}
265 
266 // Tries to find a Host device from `devices`. Returns empty string if no
267 // matching Host device is found.
TryFindHostDevice(const gtl::FlatSet<string> & devices,bool has_device_cpu,const string & device)268 string TryFindHostDevice(const gtl::FlatSet<string>& devices,
269                          bool has_device_cpu, const string& device) {
270   // Force this node onto the CPU.
271   if (device.empty() && has_device_cpu) {
272     return "/device:CPU:0";
273   } else if (str_util::StrContains(device, DEVICE_GPU)) {
274     // Sometimes the cluster can have:
275     //   devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
276     // and we need to handle them properly.
277     for (const auto& device_match :
278          {std::pair<string, string>("GPU", "CPU:0"),
279           std::pair<string, string>("/device", "/device:CPU:0")}) {
280       const string device_host =
281           strings::StrCat(device.substr(0, device.rfind(device_match.first)),
282                           device_match.second);
283       if (devices.find(device_host) != devices.end()) {
284         return device_host;
285       }
286     }
287   }
288 
289   // We couldn't find an appropriate Host device, return no device.
290   return "";
291 }
292 
IsTPUGraphDef(const GraphDef & def)293 bool IsTPUGraphDef(const GraphDef& def) {
294   for (const auto& node : def.node()) {
295     if (node.op() == "TPUCompile" || node.op() == "TPUExecute" ||
296         node.op() == "TPUPartitionedCall") {
297       return true;
298     }
299   }
300   return false;
301 }
302 }  // end namespace internal
303 
// Pins small, Host-friendly ops (small int/string tensors with CPU kernels)
// to a Host device, then greedily reverts swapped Const nodes whose consumers
// turn out not to read them from Host memory.
Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                    GraphDef* optimized_graph) {
  *optimized_graph = item.graph;

  // Skip all TPU graphs.
  if (internal::IsTPUGraphDef(*optimized_graph)) {
    return Status::OK();
  }

  GraphProperties properties(item);
  GraphView graph(optimized_graph);

  // Collect the set of device names available; without a cluster, assume only
  // the default CPU device exists.
  gtl::FlatSet<string> devices;
  if (cluster) {
    const std::vector<string> device_names = cluster->GetDeviceNames();
    devices.insert(device_names.begin(), device_names.end());
  } else {
    devices = {"/device:CPU:0"};
  }

  const bool has_device_cpu = devices.find("/device:CPU:0") != devices.end();

  // Topologically sort the graph, so that we traverse the nodes in order. This
  // will help us discover producer->consumer chains of Host ops.
  TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));

  // All the Const nodes, and their original devices in topological order.
  // Tracked so swapped Consts can be reverted in the second pass below.
  std::vector<std::pair<NodeDef*, string>> const_nodes;

  // First pass: swap every Host candidate onto a matching Host device.
  for (auto& node : *optimized_graph->mutable_node()) {
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
    bool is_candidate = false;
    TF_RETURN_IF_ERROR(
        internal::IsNodeHostCandidate(graph, &properties, node, &is_candidate));
    if (!is_candidate) {
      continue;
    }

    string device =
        internal::TryFindHostDevice(devices, has_device_cpu, node.device());
    if (!device.empty()) {
      // Keep track of all Const nodes that we swapped.
      if (IsConstant(node)) {
        const_nodes.emplace_back(&node, node.device());
      }
      *node.mutable_device() = std::move(device);
    }
  }

  // Second pass: traverse all `const_nodes`, and map them back to GPU
  // greedily.
  for (auto& it : const_nodes) {
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
    NodeDef* node = it.first;
    const string& device = it.second;

    // Check all the consumers of this node, if any of them are not on CPU, swap
    // this node back onto the original device.
    for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) {
      // The consumer is not Host friendly, swap it back to the original device.
      if (!internal::IsNodeInputPortHostFriendly(*fanout.node,
                                                 fanout.port_id)) {
        node->set_device(device);
        break;
      }
    }
  }
  return Status::OK();
}
372 
373 }  // end namespace grappler
374 }  // end namespace tensorflow
375