/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATION_GRAPH_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATION_GRAPH_H_

#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/inspecting_placer.h"
#include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/port.h"

namespace tensorflow {

// Represents a node in the disjoint node forest and the
// accumulated constraints on the device used by that node.
36 class Member { 37 public: 38 Member() = default; 39 40 Status SetParentAndSupportedDevices( 41 const Node& node, const std::vector<DeviceType>& types, 42 const DeviceNameUtils::ParsedName* local_address_spec); 43 requested_device_name()44 const DeviceNameUtils::ParsedName& requested_device_name() const { 45 return requested_device_name_; 46 } 47 48 Status SetAssignedDeviceName(const string& device_name); 49 Status SetResourceDeviceName(const Node& node); 50 Status SetRequestedDeviceName(const Node& node); 51 52 Status FillPossibleDevices(PossibleDevices* possible_device) const; 53 54 Status EnsureCompatibilityAcrossResourceEdge( 55 const Node& src, const Member& src_root, 56 const Node& dst, /*dst_root is this*/ 57 bool log_device_placement); 58 supported_device_types()59 const PrioritizedDeviceTypeVector& supported_device_types() const { 60 return supported_device_types_; 61 } 62 63 // If `dry_run` is true, just sets `new_root` and `old_root` and does not 64 // actually modify anything in the `tree`. 65 static void Merge(std::vector<Member>* tree, int x_root, int y_root, 66 Member** new_root, Member** old_root, bool dry_run); 67 68 // Returns the root node of the disjoint tree to which the node with the 69 // given id is connected. 70 // FindRoot should be called only for debugging or after the members have 71 // been updated with direct root pointers because it does not update 72 // root pointers and can traverse many links. It exists to have 73 // a const version of FindAndUpdateRoot 74 static int FindRoot(const std::vector<Member>& tree, int node_id); 75 static int FindAndUpdateRoot(std::vector<Member>* tree, int node_id); 76 77 Status MergeDeviceNames(const Member& other, bool allow_soft_placement); 78 79 // Updates this to contain the intersection of the device types in 80 // this and "other". If the intersection is empty, returns false and does 81 // not update this. Else returns true and updates this. 
82 bool MergeSupportedDevices(const Member& other); 83 84 Status AssignDevice(const Node& node); 85 86 // If user does not explicitly request XLA device and non-XLA device is 87 // supported for this node, use only the non-XLA device. See b/140896502. 88 void MaybeExcludeXlaDevices(); 89 90 // Limit the possible devices of this (should be a root) to the device 91 // specifications in `devices`. 92 Status LimitToPossibleDevices(const PossibleDevices& devices, 93 bool allow_soft_placement); 94 set_possible_devices(std::vector<Device * > && devices)95 void set_possible_devices(std::vector<Device*>&& devices) { 96 possible_devices_ = devices; 97 } possible_devices()98 const std::vector<Device*>& possible_devices() { return possible_devices_; } 99 100 // Returns a (parsed) device name that is based on requested_device_name() 101 // but with potentially cleared device type and ID fields. A field is cleared 102 // if the assigned_device_name does not specify it. If it does, the field 103 // is not cleared because soft placement cannot violate assigned device names. 104 DeviceNameUtils::ParsedName GetSoftDeviceName() const; 105 106 // Same as GetSoftDeviceName but device type and device ID fields are not 107 // cleared if resource device has them set. 108 DeviceNameUtils::ParsedName GetPreferredSoftDeviceName() const; 109 110 string DebugString() const; 111 112 private: 113 // Updates this to contain the intersection of the device types in 114 // this and `other_devices`. 115 bool MergeSupportedDevices(const PrioritizedDeviceTypeVector& other_devices); 116 117 // The id of the node that is the parent of this one, or its own 118 // id if it is a root. parent <= 0 indicates that this member is invalid. 119 int parent_ = -1; 120 121 // A proxy for the depth of the tree that is used to prefer 122 // connecting smaller trees to larger trees when merging disjoint 123 // sets. 
124 int rank_ = 0; 125 126 // Once colocation groups have been formed, the Placer starts actually 127 // choosing devices. All nodes in a group must be assigned to the same 128 // device. Once we assigned the first device to some node in this group, 129 // we set assigned_device_name_index to this device name's index in the 130 // graph. 131 // The `*_device_name_` fields will contain the parsed name of this device 132 // and `possible_devices`, if computed, will contain just this device. 133 // `assigned_device_name_index` is an optimization to avoid parsing and 134 // comparing device names. The value of -1 signals that a single device 135 // has not been chosen yet. 136 int assigned_device_name_index_ = -1; 137 138 // The merged form of the device requested for this node, with those of all of 139 // its children. requested_device_name_ is always kept a specialization (i.e. 140 // DeviceNameUtils::IsSpecification) of assigned_device_name_. When no device 141 // is requested, this field is set to assigned_device_name_. As a 142 // specialization of assigned_device_name_, requested_device_name_ represents 143 // the most specific form of all assigned and requested devices of this node 144 // and its children, if this node is a root. requested_device_name_ is used 145 // to finally select devices for nodes. We can override requested devices due 146 // to resource colocation constraints but not assigned devices (unless soft 147 // placement is on). 148 // INVARIANT: requested_device_name_ is always kept a 149 // DeviceNameUtils::IsSpecification of assigned_device_name_ and 150 // resource_device_name_. This makes requested_device_name_ the "accumulation 151 // of all wishes" about the device. 152 DeviceNameUtils::ParsedName requested_device_name_; 153 154 // The merged form of the device assigned for this node, with 155 // those of all of its children. 156 // This field is used to raise errors due to unsatisfiable constraints. 157 // Can be a partial specification. 
158 DeviceNameUtils::ParsedName assigned_device_name_; 159 160 // The merged form of the requested resource device assigned for this node, 161 // with those of all of its children. 162 // This field is used to raise errors due to unsatisfiable constraints. 163 // Can be a partial specification. 164 // resource_device_name_ is initialized with user-requested device on nodes 165 // producing resources, e.g. VarHandleOp. 166 // For historical reasons, with soft placement enabled, Placer can "move" 167 // resources (place resource producing ops on a device different from what 168 // the user explicitly requested) when the colocation group of a resource 169 // producing op contains ops that are not supported on the user-requested 170 // resource device. A classic example of this is a sparse optimizer (only 171 // supported on CPU) used on a GPU variable. In this case, the whole group 172 // will be assigned to some device supported by all ops in the colocation 173 // group. This is a surprising and unfortunate behavior because: 174 // 1. Since soft_placement is on by default, users don't know that their 175 // variables are created on a different device than what they requested. 176 // Among other things, this can lead to surprising poor performance. 177 // 2. Eager runtime cannot "move" resources. The same code can "work" when 178 // wrapped in tf.function but will fail when run eagerly. 179 // 3. Extra complexity here to preserve these resource moving capabilities. 180 DeviceNameUtils::ParsedName resource_device_name_; 181 182 // The intersection of all device types supported by this node, 183 // and those of all of its children, in priority order 184 // of the preferred device. 185 // It is possible that supported_device_types_ has an empty intersection with 186 // requested/assigned/resource devices. We could have detected such cases 187 // as soon as they happen and raise an error. 
Instead, for historical reasons, 188 // we leave such error detection to the final device picking stage. 189 PrioritizedDeviceTypeVector supported_device_types_; 190 191 // If this node is a root, stores a list of Devices to which this node 192 // and all of its children can be assigned. 193 // `possible_devices` is empty if they have not yet been computed. 194 std::vector<Device*> possible_devices_; 195 }; 196 197 // This class maintains the connected components of a colocation 198 // constraint graph, and uses this information to assign a satisfying 199 // device placement to the nodes of the graph. 200 // 201 // This implementation uses the Union-Find algorithm to efficiently maintain the 202 // connected components and incrementally adds edges via 203 // ColocationGraph::ColocateNodes() invocations. 204 // 205 // ColocationGraph does not assign any devices to graph nodes. The 206 // `log_device_placement` argument is used to log messages when requested 207 // device is ignored. 208 class ColocationGraph { 209 public: 210 // graph, flib_def, and device_set must not be null and must outlive 211 // this ColocationGraph. default_local_device can be null. If not, must 212 // outlive this. 213 ColocationGraph(const Graph* graph, const FunctionStack& stack, 214 const FunctionLibraryDefinition* flib_def, 215 const DeviceSet* device_set, 216 const Device* default_local_device, bool allow_soft_placement, 217 bool log_device_placement); 218 219 Status Initialize(); 220 members()221 const std::vector<Member>& members() const { return members_; } 222 223 // Limit the group containing `node` to the device specifications in 224 // `devices`. 225 Status LimitToPossibleDevices(const Node& node, 226 const PossibleDevices& devices); 227 228 // Limits the possible devices of `node`'s colocation group to the device 229 // to which `node` is assigned. This makes sure that all nodes in this 230 // colocation group will be assigned to the same device. 
Without this 231 // explicit restriction, heuristics can choose a different possible device 232 // for other nodes in the group. 233 Status LimitToAssignedDevice(const Node& node); 234 235 // Returns the root node of the disjoint tree to which the node with the 236 // given id is connected. 237 // Updates the internal pointers so that future calls will returns faster. FindAndUpdateRoot(int node_id)238 int FindAndUpdateRoot(int node_id) { 239 return Member::FindAndUpdateRoot(&members_, node_id); 240 } 241 242 // For the given node, subject to the constraints previously given 243 // to this ColocationGraph, set its assigned_device_name. Returns OK 244 // if a satisfying device can be found, otherwise an error. 245 // 246 // Note: This method returns a pointer to a field within members_. 247 // The caller must not use the returned pointer after there is any possibility 248 // that the members_[i].possible_devices field has been modified. 249 Status GetDevicesForNode(Node* node, 250 const std::vector<Device*>** possible_devices); 251 252 // Returns debugging info for the node referred to by 'node_root'. 253 string DebugInfo(const int node_root) const; 254 255 string DebugString() const; 256 257 // Returns a list of devices having type in supported_device_types. The 258 // returned list is sorted by preferred type (higher numeric type is 259 // preferred). 260 static std::vector<Device*> FilterSupportedDevices( 261 const std::vector<Device*>& devices, 262 const PrioritizedDeviceTypeVector& supported_device_types, 263 const Device* default_local_device); 264 265 private: 266 // Adds each node of the Graph to this ColocationGraph as a singleton. 267 // 268 // NOTE: The implementation assumes that the ids of nodes passed to 269 // this method are dense and zero-based; the memory used will be linear in 270 // the largest node ID. 271 // NOTE: If this method returns an error, *this is left in an undefined 272 // state. 
273 Status ColocateAllNodes(); 274 275 Status ColocateResourceOrRefEdge(const Node* src, const Node* dst); 276 277 // Updates this ColocationGraph by making sure that all nodes 278 // touching resource and/or ref tensors are colocated. 279 // As it iterates over the edges, fills the `inspection_required` set with 280 // the nodes that 281 // PlacerInspectionRequiredOpChecker::IsPlacerInspectionRequired 282 // deems as requiring deep inspection by placer. This is an optimization. 283 Status ColocateResourceAndRefEdges( 284 std::unordered_set<Node*>* inspection_required); 285 286 Status AddInspectionConstraints( 287 const std::unordered_set<Node*>& inspection_required); 288 289 // Applies colocation groups for `node`'s inputs and outputs to this 290 // ColocationGraph. 291 // `groups` are the colocation groups to which `nodes`'s inputs and outputs 292 // belong. 293 // `node` is a node requiring deep inspection (e.g. a node calling 294 // a function) 295 // 296 // For example, consider a `node` taking two inputs and producing one output 297 // a b 298 // | | 299 // v v 300 // node 301 // | 302 // v 303 // c 304 // 305 // `groups` can tell us that `a` and `c` must be colocated and their device 306 // must be a GPU. `b` might be in a group by itself without any device 307 // restrictions. 308 // 309 // ApplyIOColocationGroups will have an effect of calling 310 // ColocateNodes(a, c) and LimitToPossibleDevices(`a`, "GPU"). The colocation 311 // group of the `node` itself is not directly impacted. 312 // 313 Status ApplyIOColocationGroups(const IOColocationGroups& groups, 314 const Node& node); 315 316 Status ColocateNodeToGroup( 317 std::unordered_map<StringPiece, const Node*, StringPieceHasher>* 318 colocation_group_root, 319 const Node* node, StringPiece colocation_group); 320 321 // Merge the (possibly disjoint) sets containing nodes "x" and 322 // "y". Returns OK if the all nodes in the union of these sets can 323 // be placed on the same device type. 
324 // 325 // If this method returns an error, *this is unchanged. 326 Status ColocateNodes(const Node& x, const Node& y); 327 328 // This overload of ColocateNodes() allows a caller to provide the root node 329 // ids for the two nodes. For large graphs, this noticeably reduces the 330 // graph load time. 331 // If this method returns an error, *this is unchanged. 332 Status ColocateNodes(const Node& x, int x_root, const Node& y, int y_root); 333 334 void GetSoftDeviceCandidates(const Node& node, const Member& root_member, 335 int root_id, 336 std::vector<Device*>* possible_devices); 337 338 Status InitializeMembers(); 339 340 Status InitializeMemberWithAssignedDevice(const string& assigned_device_name, 341 const string& node_type, 342 Member* member); 343 344 Status InitializeMember(const Node& node, Member* member); 345 346 // Returns the root node of the disjoint tree to which the node with the 347 // given id is connected. 348 // FindRoot should be called only for debugging or after the members have 349 // been updated with direct root pointers because it does not update 350 // root pointers and can traverse many links. It exists to have 351 // a const version of FindAndUpdateRoot FindRoot(int node_id)352 int FindRoot(int node_id) const { 353 return Member::FindRoot(members_, node_id); 354 } 355 356 const Graph& graph_; 357 const FunctionStack stack_; 358 const FunctionLibraryDefinition& flib_def_; 359 std::vector<Member> members_; 360 InspectingPlacer inspecting_placer_; 361 PlacerInspectionRequiredOpChecker inspection_required_checker_; 362 const DeviceSet& device_set_; 363 const std::vector<DeviceType> device_types_; 364 const DeviceNameUtils::ParsedName local_address_spec_; 365 const Device* default_local_device_; 366 const bool allow_soft_placement_; 367 const bool log_device_placement_; 368 369 TF_DISALLOW_COPY_AND_ASSIGN(ColocationGraph); 370 }; 371 372 } // namespace tensorflow 373 374 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATION_GRAPH_H_ 375