1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATION_GRAPH_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATION_GRAPH_H_ 18 19 #include <unordered_map> 20 #include <vector> 21 22 #include "absl/strings/str_join.h" 23 #include "tensorflow/core/common_runtime/device.h" 24 #include "tensorflow/core/framework/types.h" 25 #include "tensorflow/core/lib/core/stringpiece.h" 26 #include "tensorflow/core/util/device_name_utils.h" 27 #include "tensorflow/core/util/port.h" 28 29 namespace tensorflow { 30 31 // Represents a node in the disjoint node forest and the 32 // accumulated constraints on the device used by that node. 33 class Member { 34 public: 35 Member() = default; 36 37 Status SetParentAndSupportedDevices(const Node& node, 38 const std::vector<DeviceType>& types); 39 requested_device_name()40 const DeviceNameUtils::ParsedName& requested_device_name() const { 41 return requested_device_name_; 42 } 43 44 Status SetAssignedDeviceName(const string& device_name); 45 46 Status SetRequestedDeviceName(const Node& node); 47 48 Status EnsureCompatibilityAcrossResourceEdge( 49 const Node& src, const Member& src_root, 50 const Node& dst, /*dst_root is this*/ 51 bool log_device_placement); 52 supported_device_types()53 const PrioritizedDeviceTypeVector& supported_device_types() const { 54 return supported_device_types_; 55 } 56 57 // If `dry_run` is true, just sets `new_root` and `old_root` and does not 58 // actually modify anything in the `tree`. 59 static void Merge(std::vector<Member>* tree, int x_root, int y_root, 60 Member** new_root, Member** old_root, bool dry_run); 61 62 // tree is non-const because we can change some `parent` pointers in some 63 // members for more efficient future lookups. The vector itself is not 64 // changed. 65 static int FindRoot(std::vector<Member>* tree, int node_id); 66 67 Status MergeDeviceNames(const Member& other, bool allow_soft_placement); 68 69 // Updates this to contain the intersection of the device types in 70 // this and "other". If the intersection is empty, returns false and does 71 // not update this. Else returns true and updates this. 72 bool MergeSupportedDevices(const Member& other); 73 74 Status AssignDevice(const Node& node, bool allow_soft_placement); 75 set_possible_devices(std::vector<Device * > && devices)76 void set_possible_devices(std::vector<Device*>&& devices) { 77 possible_devices_ = devices; 78 } possible_devices()79 const std::vector<Device*>& possible_devices() { return possible_devices_; } 80 81 string DebugString(); 82 83 private: 84 // The id of the node that is the parent of this one, or its own 85 // id if it is a root. parent <= 0 indicates that this member is invalid. 86 int parent_ = -1; 87 88 // A proxy for the depth of the tree that is used to prefer 89 // connecting smaller trees to larger trees when merging disjoint 90 // sets. 91 int rank_ = 0; 92 93 // Once colocation groups have been formed, the Placer starts actually 94 // choosing devices. All nodes in a group must be assigned to the same 95 // device. Once we assigned the first device to some node in this group, 96 // we set assigned_device_name_index to this device name's index in the 97 // graph. 98 // The `*_device_name_` fields will contain the parsed name of this device 99 // and `possible_devices`, if computed, will contain just this device. 100 // `assigned_device_name_index` is an optimization to avoid parsing and 101 // comparing device names. The value of -1 signals that a single device 102 // has not been chosen yet. 103 int assigned_device_name_index_ = -1; 104 105 // The merged form of the device requested for this node, with those of all of 106 // its children. requested_device_name_ is always kept a specialization (i.e. 107 // DeviceNameUtils::IsSpecialization) of assigned_device_name_. When no device 108 // is requested, this field is set to assigned_device_name_. As a 109 // specialization of assigned_device_name_, requested_device_name_ represents 110 // the most specific form of all assigned and requested devices of this node 111 // and its children, if this node is a root. requested_device_name_ is used 112 // to finally select devices for nodes. We can override requested devices due 113 // to resource colocation constraints but not assigned devices (unless soft 114 // placement is on). 115 DeviceNameUtils::ParsedName requested_device_name_; 116 117 // The merged form of the device assigned for this node, with 118 // those of all of its children. 119 // This field is used to raise errors due to unsatisfiable constraints. 120 // Can be a partial specification. 121 // INVARIANT: requested_device_name_ is always a 122 // DeviceNameUtils::IsSpecialization of assigned_device_name_. 123 DeviceNameUtils::ParsedName assigned_device_name_; 124 125 // The intersection of all device types supported by this node, 126 // and those of all of its children, in priority order 127 // of the preferred device. 128 PrioritizedDeviceTypeVector supported_device_types_; 129 130 // If this node is a root, stores a list of Devices to which this node 131 // and all of its children have been assigned, or nullptr if this 132 // has not yet been computed. 133 std::vector<Device*> possible_devices_; 134 }; // namespace 135 136 // This class maintains the connected components of a colocation 137 // constraint graph, and uses this information to assign a satisfying 138 // device placement to the nodes of the graph. 139 // 140 // The typical usage pattern is: 141 // 142 // Graph graph = ...; 143 // DeviceSet device_set = ...; 144 // ColocationGraph colocation_graph(graph, device_set); 145 // 146 // // Add all the nodes of the `graph` to the `colocation_graph`. 147 // for (Node* node : graph.nodes()) { 148 // TF_RETURN_IF_ERROR(colocation_graph.AddNode(*node)); 149 // } 150 // 151 // // Add one or more colocation constraints. 152 // Node node_1 = *graph.FindNodeId(...); 153 // Node node_2 = *graph.FindNodeId(...); 154 // TF_RETURN_IF_ERROR(colocation_graph.ColocateNodes(node_1, node_2)); 155 // 156 // // Assign devices based on the accumulated constraints. 157 // for (Node* node : graph.nodes()) { 158 // TF_RETURN_IF_ERROR(colocation_graph.AssignDevice(node)); 159 // } 160 // 161 // This implementation uses the Union-Find algorithm to efficiently maintain the 162 // connected components and incrementally adds edges via 163 // ColocationGraph::ColocateNodes() invocations. 164 // 165 // ColocationGraph does not assign any devices to graph nodes. The 166 // `log_device_placement` argument is used to log messages when requested 167 // device is ignored. 168 class ColocationGraph { 169 public: 170 ColocationGraph(const Graph* graph, const DeviceSet* device_set, 171 const Device* default_device, bool allow_soft_placement, 172 bool log_device_placement); 173 174 // Adds each node of the Graph to this ColocationGraph as a singleton. 175 // 176 // NOTE: The implementation assumes that the ids of nodes passed to 177 // this method are dense and zero-based; the memory used will be linear in 178 // the largest node ID. 179 // NOTE: If this method returns an error, *this is left in an undefined 180 // state. 181 Status ColocateAllNodes(); 182 183 Status ColocateResourceOrRefEdge(Node* src, Node* dst); 184 185 Status ColocateResourceAndRefEdges(); 186 187 Status Initialize(); 188 189 Status ColocateNodeToGroup( 190 std::unordered_map<StringPiece, const Node*, StringPieceHasher>* 191 colocation_group_root, 192 const Node* node, StringPiece colocation_group); 193 194 // Merge the (possibly disjoint) sets containing nodes "x" and 195 // "y". Returns OK if the all nodes in the union of these sets can 196 // be placed on the same device type. 197 // 198 // If this method returns an error, *this is unchanged. 199 Status ColocateNodes(const Node& x, const Node& y); 200 201 // This overload of ColocateNodes() allows a caller to provide the root node 202 // ids for the two nodes. For large graphs, this noticeably reduces the 203 // graph load time. 204 // If this method returns an error, *this is unchanged. 205 Status ColocateNodes(const Node& x, int x_root, const Node& y, int y_root); 206 207 // Limits the possible devices of `node`'s colocation group to the device 208 // to which `node` is assigned. This makes sure that all nodes in this 209 // colocation group will be assigned to the same device. Without this 210 // explicit restriction, heuristics can choose a different possible device 211 // for other nodes in the group. 212 Status LimitToAssignedDevice(const Node& node); 213 214 // For the given node, subject to the constraints previously given 215 // to this ColocationGraph, set its assigned_device_name. Returns OK 216 // if a satisfying device can be found, otherwise an error. 217 // 218 // Note: This method returns a pointer to a field within members_. 219 // The caller must not use the returned pointer after there is any possibility 220 // that the members_[i].possible_devices field has been modified. 221 Status GetDevicesForNode(Node* node, 222 const std::vector<Device*>** possible_devices); 223 224 Status InitializeMembers(); 225 226 string DebugString(); 227 228 // Returns debugging info for the node referred to by 'node_root'. 229 string DebugInfo(const int node_root); 230 231 Status InitializeMemberWithAssignedDevice(const string& assigned_device_name, 232 const string& node_type, 233 bool must_be_full_name, 234 Member* member); 235 236 Status InitializeMember(const Node& node, Member* member); 237 238 // Returns the root node of the disjoint tree to which the node with the 239 // given id is connected. FindRoot(int node_id)240 int FindRoot(int node_id) { return Member::FindRoot(&members_, node_id); } 241 242 const Graph* const graph_; // Not owned. 243 std::vector<Member> members_; 244 const DeviceSet* device_set_; // Not owned. 245 const std::vector<DeviceType> device_types_; 246 const Device* default_device_; 247 const bool allow_soft_placement_; 248 const bool log_device_placement_; 249 }; 250 251 } // namespace tensorflow 252 253 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATION_GRAPH_H_ 254