1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_SET_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_SET_H_ 18 19 #include <memory> 20 #include <unordered_map> 21 #include <vector> 22 23 #include "tensorflow/core/common_runtime/device.h" 24 #include "tensorflow/core/platform/macros.h" 25 #include "tensorflow/core/platform/types.h" 26 #include "tensorflow/core/util/device_name_utils.h" 27 28 namespace tensorflow { 29 30 typedef std::vector<std::pair<Device*, int32>> PrioritizedDeviceVector; 31 32 // DeviceSet is a container class for managing the various types of 33 // devices used by a model. 34 class DeviceSet { 35 public: 36 DeviceSet(); 37 ~DeviceSet(); 38 39 // Does not take ownership of 'device'. 40 void AddDevice(Device* device) TF_LOCKS_EXCLUDED(devices_mu_); 41 42 // Set the device designated as the "client". This device 43 // must also be registered via AddDevice(). set_client_device(Device * device)44 void set_client_device(Device* device) { 45 DCHECK(client_device_ == nullptr); 46 client_device_ = device; 47 } 48 49 // Returns a pointer to the device designated as the "client". client_device()50 Device* client_device() const { return client_device_; } 51 52 // Return the list of devices in this set. devices()53 const std::vector<Device*>& devices() const { return devices_; } 54 55 // Given a DeviceNameUtils::ParsedName (which may have some 56 // wildcards for different components), fills "*devices" with all 57 // devices in "*this" that match "spec". 58 void FindMatchingDevices(const DeviceNameUtils::ParsedName& spec, 59 std::vector<Device*>* devices) const; 60 61 // Finds the device with the given "fullname". Returns nullptr if 62 // not found. 63 Device* FindDeviceByName(const string& fullname) const; 64 65 // Return the list of unique device types in this set, ordered 66 // with more preferable devices earlier. 67 std::vector<DeviceType> PrioritizedDeviceTypeList() const; 68 69 // Return the prioritized list of devices in this set. 70 // Devices are prioritized first by `DeviceTypeOrder`, then by name. 71 const PrioritizedDeviceVector& prioritized_devices() const 72 TF_LOCKS_EXCLUDED(devices_mu_); 73 74 // Return the prioritized list of unique device types in this set. 75 // 76 // The list will be ordered by decreasing priority. The priorities (the second 77 // element in the list's `std::pair<DeviceType, int32>`) will be initialized 78 // to the value of `DeviceTypeOrder` for the device types. 79 const PrioritizedDeviceTypeVector& prioritized_device_types() const 80 TF_LOCKS_EXCLUDED(devices_mu_); 81 82 // An order to sort by device types according to system-determined 83 // priority. 84 // 85 // Higher result implies higher priority. 86 static int DeviceTypeOrder(const DeviceType& d); 87 88 // Sorts a PrioritizedDeviceVector according to devices and explicit 89 // priorities. 90 // 91 // After a call to this function, the argument vector will be sorted by 92 // explicit priority (the second element in the `std::pair<DeviceType, 93 // int32>`), then by `DeviceTypeOrder` of the device type, then by device 94 // locality, and lastly by device name. 95 static void SortPrioritizedDeviceVector(PrioritizedDeviceVector* vector); 96 97 // Sorts a PrioritizedDeviceTypeVector according to types and explicit 98 // priorities. 99 // 100 // After a call to this function, the argument vector will be sorted by 101 // explicit priority (the second element in the `std::pair<DeviceType, 102 // int32>`), then by `DeviceTypeOrder` of the device type. 103 static void SortPrioritizedDeviceTypeVector( 104 PrioritizedDeviceTypeVector* vector); 105 106 private: 107 mutable mutex devices_mu_; 108 109 // Not owned. 110 std::vector<Device*> devices_; 111 112 // Cached prioritized vector, created on-the-fly when 113 // prioritized_devices() is called. 114 mutable PrioritizedDeviceVector prioritized_devices_ 115 TF_GUARDED_BY(devices_mu_); 116 117 // Cached prioritized vector, created on-the-fly when 118 // prioritized_device_types() is called. 119 mutable PrioritizedDeviceTypeVector prioritized_device_types_ 120 TF_GUARDED_BY(devices_mu_); 121 122 // Fullname -> device* for device in devices_. 123 std::unordered_map<string, Device*> device_by_name_; 124 125 // client_device_ points to an element of devices_ that we consider 126 // to be the client device (in this local process). 127 Device* client_device_ = nullptr; 128 129 TF_DISALLOW_COPY_AND_ASSIGN(DeviceSet); 130 }; 131 132 } // namespace tensorflow 133 134 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_SET_H_ 135