• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/placer.h"
17 
18 #include <memory>
19 #include <set>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/core/common_runtime/colocation_graph.h"
25 #include "tensorflow/core/common_runtime/device.h"
26 #include "tensorflow/core/framework/attr_value_util.h"
27 #include "tensorflow/core/framework/device_attributes.pb.h"
28 #include "tensorflow/core/framework/graph.pb.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/framework/op_kernel.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/framework/types.pb.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/core/stringpiece.h"
35 #include "tensorflow/core/lib/strings/str_util.h"
36 #include "tensorflow/core/lib/strings/strcat.h"
37 #include "tensorflow/core/util/device_name_utils.h"
38 #include "tensorflow/core/util/dump_graph.h"
39 #include "tensorflow/core/util/port.h"
40 
41 namespace tensorflow {
42 
43 namespace {
44 
45 // Returns true if the node has no inputs and produces outputs
46 // that are consumed by a single node.
47 //
48 // TODO(vrv): Currently this handles only nodes with one output, but
49 // this could be extended to handle the case where a node has many
50 // outputs that are connected to nodes in the same colocation group.
IsGeneratorNode(const Node * node)51 bool IsGeneratorNode(const Node* node) {
52   return node->num_inputs() == 0 && node->num_outputs() == 1 &&
53          !IsRefType(node->output_type(0));
54 }
55 
LogDeviceAssignment(const Node * node,bool log_device_placement)56 void LogDeviceAssignment(const Node* node, bool log_device_placement) {
57   // Log placement if log_device_placement is set.
58   if (log_device_placement) {
59     printf("%s: (%s): %s\n", node->name().c_str(), node->type_string().c_str(),
60            node->assigned_device_name().c_str());
61     LOG(INFO) << node->name() << ": "
62               << "(" << node->type_string() << ")"
63               << node->assigned_device_name();
64   }
65 }
66 
AssignAndLog(int assigned_device,Node * node,ColocationGraph * colocation_graph,bool log_device_placement)67 Status AssignAndLog(int assigned_device, Node* node,
68                     ColocationGraph* colocation_graph,
69                     bool log_device_placement) {
70   node->set_assigned_device_name_index(assigned_device);
71 
72   // Constraint the group of node to the assigned device.
73   TF_RETURN_IF_ERROR(colocation_graph->LimitToAssignedDevice(*node));
74 
75   LogDeviceAssignment(node, log_device_placement);
76   return Status::OK();
77 }
78 
79 }  // namespace
80 
Placer(Graph * graph,const DeviceSet * devices,const Device * default_device,bool allow_soft_placement,bool log_device_placement)81 Placer::Placer(Graph* graph, const DeviceSet* devices,
82                const Device* default_device, bool allow_soft_placement,
83                bool log_device_placement)
84     : graph_(graph),
85       devices_(devices),
86       default_device_(default_device),
87       allow_soft_placement_(allow_soft_placement),
88       log_device_placement_(log_device_placement) {}
89 
Placer(Graph * graph,const DeviceSet * devices,const Device * default_device)90 Placer::Placer(Graph* graph, const DeviceSet* devices,
91                const Device* default_device)
92     : Placer(graph, devices, default_device, true, false) {}
93 
Placer(Graph * graph,const DeviceSet * devices)94 Placer::Placer(Graph* graph, const DeviceSet* devices)
95     : Placer(graph, devices, nullptr, true, false) {}
96 
~Placer()97 Placer::~Placer() {}
98 
Run()99 Status Placer::Run() {
100   if (devices_->devices().empty()) {
101     return errors::FailedPrecondition("No devices are registered");
102   }
103 
104   if (VLOG_IS_ON(3)) {
105     DumpGraphToFile("placer_input", *graph_, nullptr);
106     for (const Node* node : graph_->op_nodes()) {
107       VLOG(3) << "    " << node->name() << ": requested: '"
108               << node->requested_device() << "' assigned: '"
109               << node->assigned_device_name() << "'";
110     }
111   }
112 
113   ColocationGraph colocation_graph(graph_, devices_, default_device_,
114                                    allow_soft_placement_,
115                                    log_device_placement_);
116 
117   TF_RETURN_IF_ERROR(colocation_graph.Initialize());
118 
119   // For each node, assign a device based on the constraints in the disjoint
120   // node set.
121   std::vector<Node*> second_pass;
122   for (Node* node : graph_->op_nodes()) {
123     // The graph may have come pre-populated by the framework with assigned
124     // devices (e.g., for stateful placements), so the placer should not try to
125     // place nodes that are already placed.
126     if (node->has_assigned_device_name()) {
127       TF_RETURN_IF_ERROR(colocation_graph.LimitToAssignedDevice(*node));
128       LogDeviceAssignment(node, log_device_placement_);
129       continue;
130     }
131 
132     // Heuristic A: prefer to place "generators" with their only
133     // consumers.
134     //
135     // If this is a node with no inputs and one output, we save
136     // this for a second pass, so that the consumer's placement
137     // is chosen.
138     if (IsGeneratorNode(node)) {
139       second_pass.push_back(node);
140       continue;
141     }
142 
143     const std::vector<Device*>* devices;
144     Status status = colocation_graph.GetDevicesForNode(node, &devices);
145     if (!status.ok()) {
146       return AttachDef(
147           errors::InvalidArgument("Cannot assign a device for operation ",
148                                   node->name(), ": ", status.error_message()),
149           *node);
150     }
151 
152     // Returns the first device in sorted devices list so we will always
153     // choose the same device.
154     //
155     // TODO(vrv): Factor this assignment out into a pluggable
156     // algorithm, so that Placer is responsible for enforcing
157     // preconditions and we can experiment with other algorithms when
158     // given a choice of devices. Once we have a better idea of the
159     // types of heuristics we want to use and the information needed
160     // to perform good placement we can add an interface for this.
161     int assigned_device = -1;
162 
163     // Heuristic B: If the node only operates on metadata, not data,
164     // then it is desirable to place that metadata node with its
165     // input.
166     if (IsMetadata(node)) {
167       // Make sure that the input device type is in the list of supported
168       // device types for this node.
169       const Node* input = (*node->in_edges().begin())->src();
170       // TODO(vrv): if the input is empty, consider postponing this
171       // node's assignment to the second pass, so that we handle the
172       // case where a metadata node's input comes from a backedge
173       // of a loop.
174       if (CanAssignToDevice(input->assigned_device_name(), *devices)) {
175         assigned_device = input->assigned_device_name_index();
176       }
177     }
178 
179     // Provide the default, if necessary.
180     if (assigned_device == -1) {
181       assigned_device = graph_->InternDeviceName((*devices)[0]->name());
182     }
183 
184     TF_RETURN_IF_ERROR(AssignAndLog(assigned_device, node, &colocation_graph,
185                                     log_device_placement_));
186   }
187 
188   // Perform a second pass assignment for those nodes explicitly
189   // skipped during the first pass.
190   for (Node* node : second_pass) {
191     const std::vector<Device*>* devices;
192     Status status = colocation_graph.GetDevicesForNode(node, &devices);
193     if (!status.ok()) {
194       return AttachDef(
195           errors::InvalidArgument("Cannot assign a device for operation ",
196                                   node->name(), ": ", status.error_message()),
197           *node);
198     }
199 
200     int assigned_device = -1;
201 
202     // Heuristic A application.
203     if (IsGeneratorNode(node) && !node->out_edges().empty()) {
204       const Node* output = (*node->out_edges().begin())->dst();
205       int output_device_name = output->assigned_device_name_index();
206 
207       const bool consumers_on_same_device = std::all_of(
208           node->out_edges().begin(), node->out_edges().end(),
209           [output_device_name](const Edge* e) {
210             return e->dst()->assigned_device_name_index() == output_device_name;
211           });
212 
213       if (consumers_on_same_device &&
214           CanAssignToDevice(output->assigned_device_name(), *devices)) {
215         assigned_device = output_device_name;
216       }
217     }
218 
219     // Provide the default, if necessary.
220     if (assigned_device == -1) {
221       assigned_device = graph_->InternDeviceName((*devices)[0]->name());
222     }
223 
224     TF_RETURN_IF_ERROR(AssignAndLog(assigned_device, node, &colocation_graph,
225                                     log_device_placement_));
226   }
227 
228   if (VLOG_IS_ON(3)) {
229     DumpGraphToFile("placer_output", *graph_, nullptr);
230   }
231   return Status::OK();
232 }
233 
CanAssignToDevice(const string & candidate_device_name,const std::vector<Device * > & devices) const234 bool Placer::CanAssignToDevice(const string& candidate_device_name,
235                                const std::vector<Device*>& devices) const {
236   if (!candidate_device_name.empty()) {
237     // 'devices' lists the set of devices that the placer or the user has
238     // constrained the operation to.  "candidate_device_name" must
239     // refer to a concrete Device that is in the list of 'devices'.
240     const Device* other_device =
241         devices_->FindDeviceByName(candidate_device_name);
242     if (std::find(devices.begin(), devices.end(), other_device) !=
243         devices.end()) {
244       return true;
245     }
246   }
247 
248   return false;
249 }
250 
251 }  // namespace tensorflow
252