1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
17
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/tensor_shape.pb.h"
20 #include "tensorflow/core/framework/types.h"
21 #include "tensorflow/core/grappler/graph_view.h"
22 #include "tensorflow/core/grappler/grappler_item.h"
23 #include "tensorflow/core/grappler/op_types.h"
24 #include "tensorflow/core/grappler/utils/symbolic_shapes.h"
25 #include "tensorflow/core/grappler/utils/topological_sort.h"
26 #include "tensorflow/core/lib/core/error_codes.pb.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/strings/str_util.h"
29
30 namespace tensorflow {
31 namespace grappler {
32 namespace internal {
33
// Maximum number of elements (not bytes) a tensor may have and still be
// considered "small" enough to pin its producer to Host.
// TODO(williamchan): Change this constant to be something smarter, maybe
// dynamically determined.
constexpr int64 kTensorMaxSize = 64;
37
38 // All the nodes that should be blacklisted and not swapped.
IsBlacklisted(const NodeDef & node)39 bool IsBlacklisted(const NodeDef& node) {
40 return
41 // Collective ops should not be swapped.
42 IsCollective(node) ||
43 // ControlFlow ops should not be swapped.
44 IsControlFlow(node) ||
45 // NoOp ops should not be swapped (due to group dependencies).
46 IsNoOp(node);
47 }
48
49 // Check if Tensor is either a string or is integer and small size
IsTensorSmall(const OpInfo::TensorProperties & prop)50 bool IsTensorSmall(const OpInfo::TensorProperties& prop) {
51 if (prop.dtype() == DataType::DT_STRING) {
52 return true;
53 }
54
55 // Check type to be int32 or int64.
56 if (prop.dtype() != DataType::DT_INT32 &&
57 prop.dtype() != DataType::DT_INT64) {
58 return false;
59 }
60
61 // Check size known and small.
62 const int64 size = NumCoefficients(prop.shape());
63 if (size < 0 || size > kTensorMaxSize) {
64 return false;
65 }
66
67 return true;
68 }
69
70 // Find KernelDef for `node`, greedily return first found from `devices`.
TryFindKernelDef(const std::vector<DeviceType> & devices,const NodeDef & node,const KernelDef ** kdef)71 Status TryFindKernelDef(const std::vector<DeviceType>& devices,
72 const NodeDef& node, const KernelDef** kdef) {
73 for (const DeviceType& device : devices) {
74 const KernelDef* kernel = nullptr;
75 Status s = FindKernelDef(device, node, &kernel, nullptr);
76 if (s.ok()) {
77 if (kdef) {
78 *kdef = kernel;
79 }
80 return Status::OK();
81 }
82 }
83
84 return errors::NotFound("Could not find KernelDef for op: ", node.op());
85 }
86
// Checks if a node's output port is host friendly.
// Roughly this means checking if the output port is on Host memory.
// Sets `*is_candidate`; only returns a non-OK Status when shape inference
// itself fails — unknown ops/kernels are logged and treated as "not friendly".
Status IsNodeOutputPortHostFriendly(const GraphView& graph,
                                    GraphProperties* properties,
                                    const NodeDef& node, int port_id,
                                    bool* is_candidate) {
  *is_candidate = false;

  // Make sure we are not a blacklisted op.
  if (IsBlacklisted(node)) {
    return Status::OK();
  }

  // Check to make sure we have the right properties (i.e., statically shaped).
  if (!properties->has_properties()) {
    // This is an expensive call, call it lazily.
    TF_RETURN_IF_ERROR(properties->InferStatically(
        /*assume_valid_feeds=*/false));
  }
  const auto& output_properties = properties->GetOutputProperties(node.name());
  // Out-of-range port: warn and conservatively answer "not a candidate"
  // instead of failing the whole pass.
  if (port_id >= output_properties.size()) {
    LOG(WARNING) << "port_id=" << port_id
                 << " but output_properties.size()=" << output_properties.size()
                 << "\n"
                 << node.DebugString();
    return Status::OK();
  }
  // Large (or non-int, non-string) outputs disqualify the port.
  if (!IsTensorSmall(output_properties[port_id])) {
    return Status::OK();
  }

  // These nodes may be optimized away downstream (even if pinned to Host), we
  // should (recursively) check their source.
  // NOTE(review): the recursion assumes no cycles through Identity nodes —
  // TODO confirm upstream graph validation guarantees this. All fanins must
  // qualify for this port to qualify.
  if (IsIdentity(node) || IsIdentityNSingleInput(node)) {
    for (const auto& fanin : graph.GetFanins(node, false)) {
      bool fanin_candidate = false;
      TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
          graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
      if (!fanin_candidate) {
        return Status::OK();
      }
    }
    *is_candidate = true;
    return Status::OK();
  }

  // Check if op's device is on CPU.
  if (str_util::StrContains(node.device(), DEVICE_CPU)) {
    *is_candidate = true;
    return Status::OK();
  }

  // Check if op's output port is pinned to HostMemory.
  const OpDef* op = nullptr;
  Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
  if (!s.ok()) {
    LOG(WARNING) << "Could not find OpDef for : " << node.op();
    return Status::OK();
  }

  // Map the port_id to output_arg_id (one OpDef output arg can cover several
  // ports for list-typed outputs).
  const int output_arg_id = OpOutputPortIdToArgId(node, *op, port_id);
  if (output_arg_id < 0) {
    LOG(WARNING) << "Invalid port: " << port_id << "!\n"
                 << node.DebugString() << "\n"
                 << op->DebugString();
    return Status::OK();
  }

  // Find the kernel: try the node's own device first, then GPU, then CPU.
  const KernelDef* kernel = nullptr;
  s = TryFindKernelDef({node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node,
                       &kernel);
  if (!s.ok()) {
    LOG(INFO) << "Could not find KernelDef for: " << node.op();
    return Status::OK();
  }

  // Check if the output_arg is pinned to Host.
  for (const string& host_memory_arg : kernel->host_memory_arg()) {
    if (op->output_arg(output_arg_id).name() == host_memory_arg) {
      *is_candidate = true;
      break;
    }
  }

  return Status::OK();
}
175
176 // Checks if a node's input port is Host friendly.
177 // Roughly this means checking if the input port is on Host memory.
IsNodeInputPortHostFriendly(const NodeDef & node,int port_id)178 bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
179 // If node is on Host, assume its inputs are Host friendly.
180 if (str_util::StrContains(node.device(), DEVICE_CPU)) {
181 return true;
182 }
183
184 // Check if op's input port is pinned to HostMemory.
185 const OpDef* op = nullptr;
186 Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
187 if (!s.ok()) {
188 LOG(WARNING) << "Could not find OpDef for : " << node.op();
189 return false;
190 }
191 const int input_arg_id = OpInputPortIdToArgId(node, *op, port_id);
192
193 // Find the kernel.
194 const KernelDef* kernel = nullptr;
195 s = internal::TryFindKernelDef(
196 {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, &kernel);
197 if (!s.ok()) {
198 LOG(INFO) << "Could not find KernelDef for: " << node.op();
199 return false;
200 }
201
202 // Check if the input_arg is pinned to Host.
203 for (const string& host_memory_arg : kernel->host_memory_arg()) {
204 if (op->input_arg(input_arg_id).name() == host_memory_arg) {
205 return true;
206 }
207 }
208
209 return false;
210 }
211
// Checks if a node is a candidate to pin to Host.
// The rough algorithm is as follows:
// 1] Check if node is blacklisted.
// 2] Check if node can run on Host.
// 3] Check all input/outputs are Host "friendly" (atm, friendly means small,
//    ints, and pinned to Host).
// Sets `*is_candidate`; only returns non-OK when shape inference fails.
Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
                           const NodeDef& node, bool* is_candidate) {
  *is_candidate = false;

  // Check if node already on CPU — nothing to analyze in that case.
  if (str_util::StrContains(node.device(), DEVICE_CPU)) {
    *is_candidate = true;
    return Status::OK();
  }

  // Skip these node types.
  if (IsBlacklisted(node)) {
    return Status::OK();
  }

  // Check the node can be run on CPU (a CPU KernelDef must exist).
  Status s = TryFindKernelDef({DEVICE_CPU}, node, nullptr);
  if (!s.ok()) {
    return Status::OK();
  }

  // Check all inputs are Host friendly: every data fanin's output port must
  // qualify; a single unfriendly input disqualifies the node.
  for (const GraphView::OutputPort& fanin :
       graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
    bool fanin_candidate = false;
    TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
        graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
    if (!fanin_candidate) {
      return Status::OK();
    }
  }

  // Check all outputs are Host friendly (small string/int tensors).
  if (!properties->has_properties()) {
    // This is an expensive call, call it lazily.
    TF_RETURN_IF_ERROR(properties->InferStatically(
        /*assume_valid_feeds=*/false));
  }
  for (const auto& prop : properties->GetOutputProperties(node.name())) {
    if (!IsTensorSmall(prop)) {
      return Status::OK();
    }
  }

  *is_candidate = true;
  return Status::OK();
}
265
266 // Tries to find a Host device from `devices`. Returns empty string if no
267 // matching Host device is found.
TryFindHostDevice(const gtl::FlatSet<string> & devices,bool has_device_cpu,const string & device)268 string TryFindHostDevice(const gtl::FlatSet<string>& devices,
269 bool has_device_cpu, const string& device) {
270 // Force this node onto the CPU.
271 if (device.empty() && has_device_cpu) {
272 return "/device:CPU:0";
273 } else if (str_util::StrContains(device, DEVICE_GPU)) {
274 // Sometimes the cluster can have:
275 // devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
276 // and we need to handle them properly.
277 for (const auto& device_match :
278 {std::pair<string, string>("GPU", "CPU:0"),
279 std::pair<string, string>("/device", "/device:CPU:0")}) {
280 const string device_host =
281 strings::StrCat(device.substr(0, device.rfind(device_match.first)),
282 device_match.second);
283 if (devices.find(device_host) != devices.end()) {
284 return device_host;
285 }
286 }
287 }
288
289 // We couldn't find an appropriate Host device, return no device.
290 return "";
291 }
292
IsTPUGraphDef(const GraphDef & def)293 bool IsTPUGraphDef(const GraphDef& def) {
294 for (const auto& node : def.node()) {
295 if (node.op() == "TPUCompile" || node.op() == "TPUExecute" ||
296 node.op() == "TPUPartitionedCall") {
297 return true;
298 }
299 }
300 return false;
301 }
302 } // end namespace internal
303
// Pins small-int/string ops to Host in two phases: (1) swap every qualifying
// node's device to a Host device; (2) greedily roll back swapped Const nodes
// whose consumers would not read them from Host memory.
Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                    GraphDef* optimized_graph) {
  *optimized_graph = item.graph;

  // Skip all TPU graphs.
  if (internal::IsTPUGraphDef(*optimized_graph)) {
    return Status::OK();
  }

  GraphProperties properties(item);
  GraphView graph(optimized_graph);

  // Collect the cluster's device names; without a cluster, assume a single
  // default CPU device.
  gtl::FlatSet<string> devices;
  if (cluster) {
    const std::vector<string> device_names = cluster->GetDeviceNames();
    devices.insert(device_names.begin(), device_names.end());
  } else {
    devices = {"/device:CPU:0"};
  }

  const bool has_device_cpu = devices.find("/device:CPU:0") != devices.end();

  // Topologically sort the graph, so that we traverse the nodes in order. This
  // will help us discover producer->consumer chains of Host ops.
  TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));

  // All the Const nodes, and their original devices in topological order.
  std::vector<std::pair<NodeDef*, string>> const_nodes;

  // Phase 1: move every Host-candidate node onto a Host device.
  for (auto& node : *optimized_graph->mutable_node()) {
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
    bool is_candidate = false;
    TF_RETURN_IF_ERROR(
        internal::IsNodeHostCandidate(graph, &properties, node, &is_candidate));
    if (!is_candidate) {
      continue;
    }

    string device =
        internal::TryFindHostDevice(devices, has_device_cpu, node.device());
    if (!device.empty()) {
      // Keep track of all Const nodes that we swapped, with their original
      // device, so phase 2 can undo the swap if needed.
      if (IsConstant(node)) {
        const_nodes.emplace_back(&node, node.device());
      }
      *node.mutable_device() = std::move(device);
    }
  }

  // Phase 2: traverse all `const_nodes`, and map them back to GPU greedily.
  for (auto& it : const_nodes) {
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
    NodeDef* node = it.first;
    const string& device = it.second;

    // Check all the consumers of this node, if any of them are not on CPU, swap
    // this node back onto the original device.
    for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) {
      // The consumer is not Host friendly, swap it back to the original device.
      if (!internal::IsNodeInputPortHostFriendly(*fanout.node,
                                                 fanout.port_id)) {
        node->set_device(device);
        break;
      }
    }
  }
  return Status::OK();
}
372
373 } // end namespace grappler
374 } // end namespace tensorflow
375