• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_FACTORY_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_FACTORY_H_
18 
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
24 
25 namespace tensorflow {
26 
27 class Device;
28 struct SessionOptions;
29 
30 class DeviceFactory {
31  public:
~DeviceFactory()32   virtual ~DeviceFactory() {}
33   static void Register(const string& device_type, DeviceFactory* factory,
34                        int priority);
35   static DeviceFactory* GetFactory(const string& device_type);
36 
37   // Append to "*devices" all suitable devices, respecting
38   // any device type specific properties/counts listed in "options".
39   //
40   // CPU devices are added first.
41   static Status AddDevices(const SessionOptions& options,
42                            const string& name_prefix,
43                            std::vector<std::unique_ptr<Device>>* devices);
44 
45   // Helper for tests.  Create a single device of type "type".  The
46   // returned device is always numbered zero, so if creating multiple
47   // devices of the same type, supply distinct name_prefix arguments.
48   static std::unique_ptr<Device> NewDevice(const string& type,
49                                            const SessionOptions& options,
50                                            const string& name_prefix);
51 
52   // Most clients should call AddDevices() instead.
53   virtual Status CreateDevices(
54       const SessionOptions& options, const string& name_prefix,
55       std::vector<std::unique_ptr<Device>>* devices) = 0;
56 
57   // Return the device priority number for a "device_type" string.
58   //
59   // Higher number implies higher priority.
60   //
61   // In standard TensorFlow distributions, GPU device types are
62   // preferred over CPU, and by default, custom devices that don't set
63   // a custom priority during registration will be prioritized lower
64   // than CPU.  Custom devices that want a higher priority can set the
65   // 'priority' field when registering their device to something
66   // higher than the packaged devices.  See calls to
67   // REGISTER_LOCAL_DEVICE_FACTORY to see the existing priorities used
68   // for built-in devices.
69   static int32 DevicePriority(const string& device_type);
70 };
71 
72 namespace dfactory {
73 
74 template <class Factory>
75 class Registrar {
76  public:
77   // Multiple registrations for the same device type with different priorities
78   // are allowed.  Priorities are used in two different ways:
79   //
80   // 1) When choosing which factory (that is, which device
81   //    implementation) to use for a specific 'device_type', the
82   //    factory registered with the highest priority will be chosen.
83   //    For example, if there are two registrations:
84   //
85   //      Registrar<CPUFactory1>("CPU", 125);
86   //      Registrar<CPUFactory2>("CPU", 150);
87   //
88   //    then CPUFactory2 will be chosen when
89   //    DeviceFactory::GetFactory("CPU") is called.
90   //
91   // 2) When choosing which 'device_type' is preferred over other
92   //    DeviceTypes in a DeviceSet, the ordering is determined
93   //    by the 'priority' set during registration.  For example, if there
94   //    are two registrations:
95   //
96   //      Registrar<CPUFactory>("CPU", 100);
97   //      Registrar<GPUFactory>("GPU", 200);
98   //
99   //    then DeviceType("GPU") will be prioritized higher than
100   //    DeviceType("CPU").
101   //
102   // The default priority values for built-in devices is:
103   // GPU: 210
104   // SYCL: 200
105   // GPUCompatibleCPU: 70
106   // ThreadPoolDevice: 60
107   // Default: 50
108   explicit Registrar(const string& device_type, int priority = 50) {
109     DeviceFactory::Register(device_type, new Factory(), priority);
110   }
111 };
112 
113 }  // namespace dfactory
114 
// Registers 'device_factory' (a DeviceFactory subclass) for 'device_type' by
// defining a static dfactory::Registrar object. An optional third argument is
// forwarded to Registrar as the registration priority; when omitted,
// Registrar's default priority is used. Example:
//
//   REGISTER_LOCAL_DEVICE_FACTORY("CPU", MyCPUFactory, 60);
#define REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, ...) \
  INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory,   \
                                         __COUNTER__, ##__VA_ARGS__)

#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \
                                               ctr, ...)                    \
  static ::tensorflow::dfactory::Registrar<device_factory>                  \
      INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr)(device_type,         \
                                                       ##__VA_ARGS__)

// Builds a per-translation-unit unique object name from a counter value.
// Operands of '##' are not macro-expanded, so __COUNTER__ must already have
// been expanded to a number before it reaches this macro. That is why
// REGISTER_LOCAL_DEVICE_FACTORY passes it through
// INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY first. The generated name has no
// leading underscore and no "__", both of which are reserved in C++.
#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr) \
  local_device_factory_registrar_##ctr##_object
127 
128 }  // namespace tensorflow
129 
130 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_FACTORY_H_
131