1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_ 18 19 #include <memory> 20 #include <string> 21 #include <unordered_map> 22 #include <unordered_set> 23 #include <vector> 24 25 #include "tensorflow/core/common_runtime/device.h" 26 #include "tensorflow/core/lib/core/arena.h" 27 #include "tensorflow/core/lib/core/status.h" 28 #include "tensorflow/core/lib/core/stringpiece.h" 29 #include "tensorflow/core/lib/gtl/inlined_vector.h" 30 #include "tensorflow/core/platform/macros.h" 31 32 namespace tensorflow { 33 34 class DeviceAttributes; 35 36 // Represents a set of devices. 37 class DeviceMgr { 38 public: 39 DeviceMgr() = default; 40 virtual ~DeviceMgr(); 41 42 // Returns attributes of all devices. 43 virtual void ListDeviceAttributes( 44 std::vector<DeviceAttributes>* devices) const = 0; 45 46 // Returns raw pointers to the underlying devices. 47 virtual std::vector<Device*> ListDevices() const = 0; 48 49 // Returns a string listing all devices. 50 virtual string DebugString() const = 0; 51 52 // Returns a string of all the device mapping. 53 virtual string DeviceMappingString() const = 0; 54 55 // Assigns *device with pointer to Device of the given name. 56 // Accepts either a full device name, or just the replica-local suffix. 57 virtual Status LookupDevice(StringPiece name, Device** device) const = 0; 58 59 // Clears given containers of all devices if 'container' is 60 // non-empty. Otherwise, clears default containers of all devices. 61 virtual void ClearContainers(gtl::ArraySlice<string> containers) const = 0; 62 63 virtual int NumDeviceType(const string& type) const = 0; 64 65 // Returns an arbitrary CPU device if one is present, otherwise return 66 // nullptr. 67 virtual Device* HostCPU() const = 0; 68 69 TF_DISALLOW_COPY_AND_ASSIGN(DeviceMgr); 70 }; 71 72 // Represents a static set of devices. 73 class StaticDeviceMgr : public DeviceMgr { 74 public: 75 // Constructs a StaticDeviceMgr from a list of devices. 76 explicit StaticDeviceMgr(std::vector<std::unique_ptr<Device>> devices); 77 78 // Constructs a StaticDeviceMgr managing a single device. 79 explicit StaticDeviceMgr(std::unique_ptr<Device> device); 80 81 ~StaticDeviceMgr() override; 82 83 void ListDeviceAttributes( 84 std::vector<DeviceAttributes>* devices) const override; 85 std::vector<Device*> ListDevices() const override; 86 string DebugString() const override; 87 string DeviceMappingString() const override; 88 Status LookupDevice(StringPiece name, Device** device) const override; 89 void ClearContainers(gtl::ArraySlice<string> containers) const override; 90 int NumDeviceType(const string& type) const override; 91 Device* HostCPU() const override; 92 93 private: 94 const std::vector<std::unique_ptr<Device>> devices_; 95 96 StringPiece CopyToBackingStore(StringPiece s); 97 98 std::unordered_map<StringPiece, Device*, StringPieceHasher> device_map_; 99 core::Arena name_backing_store_; // Storage for keys in device_map_ 100 std::unordered_map<string, int> device_type_counts_; 101 Device* cpu_device_; 102 103 TF_DISALLOW_COPY_AND_ASSIGN(StaticDeviceMgr); 104 }; 105 106 // Represents a dynamic set of devices 107 class DynamicDeviceMgr : public DeviceMgr { 108 public: 109 // Constructs an empty DynamicDeviceMgr. 110 DynamicDeviceMgr(); 111 112 ~DynamicDeviceMgr() override; 113 114 void ListDeviceAttributes( 115 std::vector<DeviceAttributes>* devices) const override; 116 std::vector<Device*> ListDevices() const override; 117 string DebugString() const override; 118 string DeviceMappingString() const override; 119 Status LookupDevice(StringPiece name, Device** device) const override; 120 void ClearContainers(gtl::ArraySlice<string> containers) const override; 121 int NumDeviceType(const string& type) const override; 122 Device* HostCPU() const override; 123 124 // Add devices to device manager. Returns error for repeated device names. 125 Status AddDevices(std::vector<std::unique_ptr<Device>> devices); 126 127 // Remove devices from device manager. 128 // Returns error for non-existing devices or if the HostCPU() device is in the 129 // input list. If an error is returned, the device list is not modified. 130 Status RemoveDevices(std::vector<Device*> devices); 131 132 // Remove devices from device manager by their names. Returns error for 133 // non-existing devices or if the HostCPU() device is given in the input list. 134 // If an error is returned, the device list is not modified. 135 Status RemoveDevicesByName(const std::vector<string>& device_names); 136 137 private: 138 mutable mutex devices_mu_; 139 140 std::unordered_map<Device*, std::unique_ptr<Device>> dynamic_devices_ 141 GUARDED_BY(devices_mu_); 142 143 std::unordered_map<string, Device*> device_map_ GUARDED_BY(devices_mu_); 144 145 std::unordered_map<string, int> device_type_counts_ GUARDED_BY(devices_mu_); 146 147 mutable Device* cpu_device_ GUARDED_BY(devices_mu_); 148 149 TF_DISALLOW_COPY_AND_ASSIGN(DynamicDeviceMgr); 150 }; 151 } // namespace tensorflow 152 153 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_ 154