• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/device_factory.h"
17 
18 #include <memory>
19 #include <string>
20 #include <unordered_map>
21 #include <vector>
22 
23 #include "tensorflow/core/framework/device.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/mutex.h"
28 #include "tensorflow/core/platform/types.h"
29 #include "tensorflow/core/public/session_options.h"
30 
31 namespace tensorflow {
32 
33 namespace {
34 
get_device_factory_lock()35 static mutex* get_device_factory_lock() {
36   static mutex device_factory_lock(LINKER_INITIALIZED);
37   return &device_factory_lock;
38 }
39 
40 struct FactoryItem {
41   std::unique_ptr<DeviceFactory> factory;
42   int priority;
43 };
44 
device_factories()45 std::unordered_map<string, FactoryItem>& device_factories() {
46   static std::unordered_map<string, FactoryItem>* factories =
47       new std::unordered_map<string, FactoryItem>;
48   return *factories;
49 }
50 
51 }  // namespace
52 
53 // static
DevicePriority(const string & device_type)54 int32 DeviceFactory::DevicePriority(const string& device_type) {
55   tf_shared_lock l(*get_device_factory_lock());
56   std::unordered_map<string, FactoryItem>& factories = device_factories();
57   auto iter = factories.find(device_type);
58   if (iter != factories.end()) {
59     return iter->second.priority;
60   }
61 
62   return -1;
63 }
64 
65 // static
Register(const string & device_type,DeviceFactory * factory,int priority)66 void DeviceFactory::Register(const string& device_type, DeviceFactory* factory,
67                              int priority) {
68   mutex_lock l(*get_device_factory_lock());
69   std::unique_ptr<DeviceFactory> factory_ptr(factory);
70   std::unordered_map<string, FactoryItem>& factories = device_factories();
71   auto iter = factories.find(device_type);
72   if (iter == factories.end()) {
73     factories[device_type] = {std::move(factory_ptr), priority};
74   } else {
75     if (iter->second.priority < priority) {
76       iter->second = {std::move(factory_ptr), priority};
77     } else if (iter->second.priority == priority) {
78       LOG(FATAL) << "Duplicate registration of device factory for type "
79                  << device_type << " with the same priority " << priority;
80     }
81   }
82 }
83 
GetFactory(const string & device_type)84 DeviceFactory* DeviceFactory::GetFactory(const string& device_type) {
85   tf_shared_lock l(*get_device_factory_lock());
86   auto it = device_factories().find(device_type);
87   if (it == device_factories().end()) {
88     return nullptr;
89   }
90   return it->second.factory.get();
91 }
92 
ListAllPhysicalDevices(std::vector<string> * devices)93 Status DeviceFactory::ListAllPhysicalDevices(std::vector<string>* devices) {
94   // CPU first. A CPU device is required.
95   auto cpu_factory = GetFactory("CPU");
96   if (!cpu_factory) {
97     return errors::NotFound(
98         "CPU Factory not registered. Did you link in threadpool_device?");
99   }
100 
101   size_t init_size = devices->size();
102   TF_RETURN_IF_ERROR(cpu_factory->ListPhysicalDevices(devices));
103   if (devices->size() == init_size) {
104     return errors::NotFound("No CPU devices are available in this process");
105   }
106 
107   // Then the rest (including GPU).
108   tf_shared_lock l(*get_device_factory_lock());
109   for (auto& p : device_factories()) {
110     auto factory = p.second.factory.get();
111     if (factory != cpu_factory) {
112       TF_RETURN_IF_ERROR(factory->ListPhysicalDevices(devices));
113     }
114   }
115 
116   return Status::OK();
117 }
118 
GetAnyDeviceDetails(int device_index,std::unordered_map<string,string> * details)119 Status DeviceFactory::GetAnyDeviceDetails(
120     int device_index, std::unordered_map<string, string>* details) {
121   if (device_index < 0) {
122     return errors::InvalidArgument("Device index out of bounds: ",
123                                    device_index);
124   }
125   const int orig_device_index = device_index;
126 
127   // Iterate over devices in the same way as in ListAllPhysicalDevices.
128   auto cpu_factory = GetFactory("CPU");
129   if (!cpu_factory) {
130     return errors::NotFound(
131         "CPU Factory not registered. Did you link in threadpool_device?");
132   }
133 
134   std::vector<string> devices;
135   TF_RETURN_IF_ERROR(cpu_factory->ListPhysicalDevices(&devices));
136   if (device_index < devices.size()) {
137     return cpu_factory->GetDeviceDetails(device_index, details);
138   }
139   device_index -= devices.size();
140 
141   // Then the rest (including GPU).
142   tf_shared_lock l(*get_device_factory_lock());
143   for (auto& p : device_factories()) {
144     auto factory = p.second.factory.get();
145     if (factory != cpu_factory) {
146       devices.clear();
147       // TODO(b/146009447): Find the factory size without having to allocate a
148       // vector with all the physical devices.
149       TF_RETURN_IF_ERROR(factory->ListPhysicalDevices(&devices));
150       if (device_index < devices.size()) {
151         return factory->GetDeviceDetails(device_index, details);
152       }
153       device_index -= devices.size();
154     }
155   }
156 
157   return errors::InvalidArgument("Device index out of bounds: ",
158                                  orig_device_index);
159 }
160 
AddDevices(const SessionOptions & options,const string & name_prefix,std::vector<std::unique_ptr<Device>> * devices)161 Status DeviceFactory::AddDevices(
162     const SessionOptions& options, const string& name_prefix,
163     std::vector<std::unique_ptr<Device>>* devices) {
164   // CPU first. A CPU device is required.
165   auto cpu_factory = GetFactory("CPU");
166   if (!cpu_factory) {
167     return errors::NotFound(
168         "CPU Factory not registered. Did you link in threadpool_device?");
169   }
170   size_t init_size = devices->size();
171   TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(options, name_prefix, devices));
172   if (devices->size() == init_size) {
173     return errors::NotFound("No CPU devices are available in this process");
174   }
175 
176   // Then the rest (including GPU).
177   mutex_lock l(*get_device_factory_lock());
178   for (auto& p : device_factories()) {
179     auto factory = p.second.factory.get();
180     if (factory != cpu_factory) {
181       TF_RETURN_IF_ERROR(factory->CreateDevices(options, name_prefix, devices));
182     }
183   }
184 
185   return Status::OK();
186 }
187 
NewDevice(const string & type,const SessionOptions & options,const string & name_prefix)188 std::unique_ptr<Device> DeviceFactory::NewDevice(const string& type,
189                                                  const SessionOptions& options,
190                                                  const string& name_prefix) {
191   auto device_factory = GetFactory(type);
192   if (!device_factory) {
193     return nullptr;
194   }
195   SessionOptions opt = options;
196   (*opt.config.mutable_device_count())[type] = 1;
197   std::vector<std::unique_ptr<Device>> devices;
198   TF_CHECK_OK(device_factory->CreateDevices(opt, name_prefix, &devices));
199   int expected_num_devices = 1;
200   auto iter = options.config.device_count().find(type);
201   if (iter != options.config.device_count().end()) {
202     expected_num_devices = iter->second;
203   }
204   DCHECK_EQ(devices.size(), static_cast<size_t>(expected_num_devices));
205   return std::move(devices[0]);
206 }
207 
208 }  // namespace tensorflow
209