/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATION_GRAPH_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATION_GRAPH_H_
18 
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/port.h"
28 
29 namespace tensorflow {
30 
31 // Represents a node in the disjoint node forest and the
32 // accumulated constraints on the device used by that node.
33 class Member {
34  public:
35   Member() = default;
36 
37   Status SetParentAndSupportedDevices(const Node& node,
38                                       const std::vector<DeviceType>& types);
39 
requested_device_name()40   const DeviceNameUtils::ParsedName& requested_device_name() const {
41     return requested_device_name_;
42   }
43 
44   Status SetAssignedDeviceName(const string& device_name);
45 
46   Status SetRequestedDeviceName(const Node& node);
47 
48   Status EnsureCompatibilityAcrossResourceEdge(
49       const Node& src, const Member& src_root,
50       const Node& dst, /*dst_root is this*/
51       bool log_device_placement);
52 
supported_device_types()53   const PrioritizedDeviceTypeVector& supported_device_types() const {
54     return supported_device_types_;
55   }
56 
57   // If `dry_run` is true, just sets `new_root` and `old_root` and does not
58   // actually modify anything in the `tree`.
59   static void Merge(std::vector<Member>* tree, int x_root, int y_root,
60                     Member** new_root, Member** old_root, bool dry_run);
61 
62   // tree is non-const because we can change some `parent` pointers in some
63   // members for more efficient future lookups. The vector itself is not
64   // changed.
65   static int FindRoot(std::vector<Member>* tree, int node_id);
66 
67   Status MergeDeviceNames(const Member& other, bool allow_soft_placement);
68 
69   // Updates this to contain the intersection of the device types in
70   // this and "other". If the intersection is empty, returns false and does
71   // not update this. Else returns true and updates this.
72   bool MergeSupportedDevices(const Member& other);
73 
74   Status AssignDevice(const Node& node, bool allow_soft_placement);
75 
set_possible_devices(std::vector<Device * > && devices)76   void set_possible_devices(std::vector<Device*>&& devices) {
77     possible_devices_ = devices;
78   }
possible_devices()79   const std::vector<Device*>& possible_devices() { return possible_devices_; }
80 
81   string DebugString();
82 
83  private:
84   // The id of the node that is the parent of this one, or its own
85   // id if it is a root. parent <= 0 indicates that this member is invalid.
86   int parent_ = -1;
87 
88   // A proxy for the depth of the tree that is used to prefer
89   // connecting smaller trees to larger trees when merging disjoint
90   // sets.
91   int rank_ = 0;
92 
93   // Once colocation groups have been formed, the Placer starts actually
94   // choosing devices. All nodes in a group must be assigned to the same
95   // device. Once we assigned the first device to some node in this group,
96   // we set assigned_device_name_index to this device name's index in the
97   // graph.
98   // The `*_device_name_` fields will contain the parsed name of this device
99   // and `possible_devices`, if computed, will contain just this device.
100   // `assigned_device_name_index` is an optimization to avoid parsing and
101   // comparing device names. The value of -1 signals that a single device
102   // has not been chosen yet.
103   int assigned_device_name_index_ = -1;
104 
105   // The merged form of the device requested for this node, with those of all of
106   // its children. requested_device_name_ is always kept a specialization (i.e.
107   // DeviceNameUtils::IsSpecialization) of assigned_device_name_. When no device
108   // is requested, this field is set to assigned_device_name_.  As a
109   // specialization of assigned_device_name_, requested_device_name_ represents
110   // the most specific form of all assigned and requested devices of this node
111   // and its children, if this node is a root. requested_device_name_ is used
112   // to finally select devices for nodes.  We can override requested devices due
113   // to resource colocation constraints but not assigned devices (unless soft
114   // placement is on).
115   DeviceNameUtils::ParsedName requested_device_name_;
116 
117   // The merged form of the device assigned for this node, with
118   // those of all of its children.
119   // This field is used to raise errors due to unsatisfiable constraints.
120   // Can be a partial specification.
121   // INVARIANT: requested_device_name_ is always a
122   // DeviceNameUtils::IsSpecialization of assigned_device_name_.
123   DeviceNameUtils::ParsedName assigned_device_name_;
124 
125   // The intersection of all device types supported by this node,
126   // and those of all of its children, in priority order
127   // of the preferred device.
128   PrioritizedDeviceTypeVector supported_device_types_;
129 
130   // If this node is a root, stores a list of Devices to which this node
131   // and all of its children have been assigned, or nullptr if this
132   // has not yet been computed.
133   std::vector<Device*> possible_devices_;
134 };  // namespace
135 
136 // This class maintains the connected components of a colocation
137 // constraint graph, and uses this information to assign a satisfying
138 // device placement to the nodes of the graph.
139 //
140 // The typical usage pattern is:
141 //
142 //   Graph graph = ...;
143 //   DeviceSet device_set = ...;
144 //   ColocationGraph colocation_graph(graph, device_set);
145 //
146 //   // Add all the nodes of the `graph` to the `colocation_graph`.
147 //   for (Node* node : graph.nodes()) {
148 //     TF_RETURN_IF_ERROR(colocation_graph.AddNode(*node));
149 //   }
150 //
151 //   // Add one or more colocation constraints.
152 //   Node node_1 = *graph.FindNodeId(...);
153 //   Node node_2 = *graph.FindNodeId(...);
154 //   TF_RETURN_IF_ERROR(colocation_graph.ColocateNodes(node_1, node_2));
155 //
156 //   // Assign devices based on the accumulated constraints.
157 //   for (Node* node : graph.nodes()) {
158 //     TF_RETURN_IF_ERROR(colocation_graph.AssignDevice(node));
159 //   }
160 //
161 // This implementation uses the Union-Find algorithm to efficiently maintain the
162 // connected components and incrementally adds edges via
163 // ColocationGraph::ColocateNodes() invocations.
164 //
165 // ColocationGraph does not assign any devices to graph nodes. The
166 // `log_device_placement` argument is used to log messages when requested
167 // device is ignored.
168 class ColocationGraph {
169  public:
170   ColocationGraph(const Graph* graph, const DeviceSet* device_set,
171                   const Device* default_device, bool allow_soft_placement,
172                   bool log_device_placement);
173 
174   // Adds each node of the Graph to this ColocationGraph as a singleton.
175   //
176   // NOTE: The implementation assumes that the ids of nodes passed to
177   // this method are dense and zero-based; the memory used will be linear in
178   // the largest node ID.
179   // NOTE: If this method returns an error, *this is left in an undefined
180   // state.
181   Status ColocateAllNodes();
182 
183   Status ColocateResourceOrRefEdge(Node* src, Node* dst);
184 
185   Status ColocateResourceAndRefEdges();
186 
187   Status Initialize();
188 
189   Status ColocateNodeToGroup(
190       std::unordered_map<StringPiece, const Node*, StringPieceHasher>*
191           colocation_group_root,
192       const Node* node, StringPiece colocation_group);
193 
194   // Merge the (possibly disjoint) sets containing nodes "x" and
195   // "y". Returns OK if the all nodes in the union of these sets can
196   // be placed on the same device type.
197   //
198   // If this method returns an error, *this is unchanged.
199   Status ColocateNodes(const Node& x, const Node& y);
200 
201   // This overload of ColocateNodes() allows a caller to provide the root node
202   // ids for the two nodes. For large graphs, this noticeably reduces the
203   // graph load time.
204   // If this method returns an error, *this is unchanged.
205   Status ColocateNodes(const Node& x, int x_root, const Node& y, int y_root);
206 
207   // Limits the possible devices of `node`'s colocation group to the device
208   // to which `node` is assigned. This makes sure that all nodes in this
209   // colocation group will be assigned to the same device. Without this
210   // explicit restriction, heuristics can choose a different possible device
211   // for other nodes in the group.
212   Status LimitToAssignedDevice(const Node& node);
213 
214   // For the given node, subject to the constraints previously given
215   // to this ColocationGraph, set its assigned_device_name. Returns OK
216   // if a satisfying device can be found, otherwise an error.
217   //
218   // Note: This method returns a pointer to a field within members_.
219   // The caller must not use the returned pointer after there is any possibility
220   // that the members_[i].possible_devices field has been modified.
221   Status GetDevicesForNode(Node* node,
222                            const std::vector<Device*>** possible_devices);
223 
224   Status InitializeMembers();
225 
226   string DebugString();
227 
228   // Returns debugging info for the node referred to by 'node_root'.
229   string DebugInfo(const int node_root);
230 
231   Status InitializeMemberWithAssignedDevice(const string& assigned_device_name,
232                                             const string& node_type,
233                                             bool must_be_full_name,
234                                             Member* member);
235 
236   Status InitializeMember(const Node& node, Member* member);
237 
238   // Returns the root node of the disjoint tree to which the node with the
239   // given id is connected.
FindRoot(int node_id)240   int FindRoot(int node_id) { return Member::FindRoot(&members_, node_id); }
241 
242   const Graph* const graph_;  // Not owned.
243   std::vector<Member> members_;
244   const DeviceSet* device_set_;  // Not owned.
245   const std::vector<DeviceType> device_types_;
246   const Device* default_device_;
247   const bool allow_soft_placement_;
248   const bool log_device_placement_;
249 };
250 
251 }  // namespace tensorflow
252 
253 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATION_GRAPH_H_
254