1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_FRAMEWORK_DEVICE_FACTORY_H_ 17 #define TENSORFLOW_CORE_FRAMEWORK_DEVICE_FACTORY_H_ 18 19 #include <string> 20 #include <vector> 21 22 #include "tensorflow/core/platform/status.h" 23 #include "tensorflow/core/platform/types.h" 24 25 namespace tensorflow { 26 27 class Device; 28 struct SessionOptions; 29 30 class DeviceFactory { 31 public: ~DeviceFactory()32 virtual ~DeviceFactory() {} 33 static void Register(const std::string& device_type, DeviceFactory* factory, 34 int priority, bool is_pluggable_device); 35 static DeviceFactory* GetFactory(const std::string& device_type); 36 37 // Append to "*devices" CPU devices. 38 static Status AddCpuDevices(const SessionOptions& options, 39 const std::string& name_prefix, 40 std::vector<std::unique_ptr<Device>>* devices); 41 42 // Append to "*devices" all suitable devices, respecting 43 // any device type specific properties/counts listed in "options". 44 // 45 // CPU devices are added first. 46 static Status AddDevices(const SessionOptions& options, 47 const std::string& name_prefix, 48 std::vector<std::unique_ptr<Device>>* devices); 49 50 // Helper for tests. Create a single device of type "type". The 51 // returned device is always numbered zero, so if creating multiple 52 // devices of the same type, supply distinct name_prefix arguments. 53 static std::unique_ptr<Device> NewDevice(const string& type, 54 const SessionOptions& options, 55 const string& name_prefix); 56 57 // Iterate through all device factories and build a list of all of the 58 // possible physical devices. 59 // 60 // CPU is are added first. 61 static Status ListAllPhysicalDevices(std::vector<string>* devices); 62 63 // Iterate through all device factories and build a list of all of the 64 // possible pluggable physical devices. 65 static Status ListPluggablePhysicalDevices(std::vector<string>* devices); 66 67 // Get details for a specific device among all device factories. 68 // 'device_index' indexes into devices from ListAllPhysicalDevices. 69 static Status GetAnyDeviceDetails( 70 int device_index, std::unordered_map<string, string>* details); 71 72 // For a specific device factory list all possible physical devices. 73 virtual Status ListPhysicalDevices(std::vector<string>* devices) = 0; 74 75 // Get details for a specific device for a specific factory. Subclasses 76 // can store arbitrary device information in the map. 'device_index' indexes 77 // into devices from ListPhysicalDevices. GetDeviceDetails(int device_index,std::unordered_map<string,string> * details)78 virtual Status GetDeviceDetails(int device_index, 79 std::unordered_map<string, string>* details) { 80 return Status::OK(); 81 } 82 83 // Most clients should call AddDevices() instead. 84 virtual Status CreateDevices( 85 const SessionOptions& options, const std::string& name_prefix, 86 std::vector<std::unique_ptr<Device>>* devices) = 0; 87 88 // Return the device priority number for a "device_type" string. 89 // 90 // Higher number implies higher priority. 91 // 92 // In standard TensorFlow distributions, GPU device types are 93 // preferred over CPU, and by default, custom devices that don't set 94 // a custom priority during registration will be prioritized lower 95 // than CPU. Custom devices that want a higher priority can set the 96 // 'priority' field when registering their device to something 97 // higher than the packaged devices. See calls to 98 // REGISTER_LOCAL_DEVICE_FACTORY to see the existing priorities used 99 // for built-in devices. 100 static int32 DevicePriority(const std::string& device_type); 101 102 // Returns true if 'device_type' is registered from plugin. Returns false if 103 // 'device_type' is a first-party device. 104 static bool IsPluggableDevice(const std::string& device_type); 105 }; 106 107 namespace dfactory { 108 109 template <class Factory> 110 class Registrar { 111 public: 112 // Multiple registrations for the same device type with different priorities 113 // are allowed. Priorities are used in two different ways: 114 // 115 // 1) When choosing which factory (that is, which device 116 // implementation) to use for a specific 'device_type', the 117 // factory registered with the highest priority will be chosen. 118 // For example, if there are two registrations: 119 // 120 // Registrar<CPUFactory1>("CPU", 125); 121 // Registrar<CPUFactory2>("CPU", 150); 122 // 123 // then CPUFactory2 will be chosen when 124 // DeviceFactory::GetFactory("CPU") is called. 125 // 126 // 2) When choosing which 'device_type' is preferred over other 127 // DeviceTypes in a DeviceSet, the ordering is determined 128 // by the 'priority' set during registration. For example, if there 129 // are two registrations: 130 // 131 // Registrar<CPUFactory>("CPU", 100); 132 // Registrar<GPUFactory>("GPU", 200); 133 // 134 // then DeviceType("GPU") will be prioritized higher than 135 // DeviceType("CPU"). 136 // 137 // The default priority values for built-in devices is: 138 // GPU: 210 139 // GPUCompatibleCPU: 70 140 // ThreadPoolDevice: 60 141 // Default: 50 142 explicit Registrar(const std::string& device_type, int priority = 50) { 143 DeviceFactory::Register(device_type, new Factory(), priority, 144 /*is_pluggable_device*/ false); 145 } 146 }; 147 148 } // namespace dfactory 149 150 #define REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, ...) \ 151 INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \ 152 __COUNTER__, ##__VA_ARGS__) 153 154 #define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \ 155 ctr, ...) \ 156 static ::tensorflow::dfactory::Registrar<device_factory> \ 157 INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr)(device_type, \ 158 ##__VA_ARGS__) 159 160 // __COUNTER__ must go through another macro to be properly expanded 161 #define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr) ___##ctr##__object_ 162 163 } // namespace tensorflow 164 165 #endif // TENSORFLOW_CORE_FRAMEWORK_DEVICE_FACTORY_H_ 166