• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_
18 
19 #include <memory>
20 #include <string>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <vector>
24 
25 #include "tensorflow/core/common_runtime/device.h"
26 #include "tensorflow/core/lib/core/arena.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/core/stringpiece.h"
29 #include "tensorflow/core/lib/gtl/inlined_vector.h"
30 #include "tensorflow/core/platform/macros.h"
31 
32 namespace tensorflow {
33 
34 class DeviceAttributes;
35 
36 // Represents a set of devices.
37 class DeviceMgr {
38  public:
39   DeviceMgr() = default;
40   virtual ~DeviceMgr();
41 
42   // Returns attributes of all devices.
43   virtual void ListDeviceAttributes(
44       std::vector<DeviceAttributes>* devices) const = 0;
45 
46   // Returns raw pointers to the underlying devices.
47   virtual std::vector<Device*> ListDevices() const = 0;
48 
49   // Returns a string listing all devices.
50   virtual string DebugString() const = 0;
51 
52   // Returns a string of all the device mapping.
53   virtual string DeviceMappingString() const = 0;
54 
55   // Assigns *device with pointer to Device of the given name.
56   // Accepts either a full device name, or just the replica-local suffix.
57   virtual Status LookupDevice(StringPiece name, Device** device) const = 0;
58 
59   // Clears given containers of all devices if 'container' is
60   // non-empty. Otherwise, clears default containers of all devices.
61   virtual void ClearContainers(gtl::ArraySlice<string> containers) const = 0;
62 
63   virtual int NumDeviceType(const string& type) const = 0;
64 
65   // Returns an arbitrary CPU device if one is present, otherwise return
66   // nullptr.
67   virtual Device* HostCPU() const = 0;
68 
69   TF_DISALLOW_COPY_AND_ASSIGN(DeviceMgr);
70 };
71 
72 // Represents a static set of devices.
73 class StaticDeviceMgr : public DeviceMgr {
74  public:
75   // Constructs a StaticDeviceMgr from a list of devices.
76   explicit StaticDeviceMgr(std::vector<std::unique_ptr<Device>> devices);
77 
78   // Constructs a StaticDeviceMgr managing a single device.
79   explicit StaticDeviceMgr(std::unique_ptr<Device> device);
80 
81   ~StaticDeviceMgr() override;
82 
83   void ListDeviceAttributes(
84       std::vector<DeviceAttributes>* devices) const override;
85   std::vector<Device*> ListDevices() const override;
86   string DebugString() const override;
87   string DeviceMappingString() const override;
88   Status LookupDevice(StringPiece name, Device** device) const override;
89   void ClearContainers(gtl::ArraySlice<string> containers) const override;
90   int NumDeviceType(const string& type) const override;
91   Device* HostCPU() const override;
92 
93  private:
94   const std::vector<std::unique_ptr<Device>> devices_;
95 
96   StringPiece CopyToBackingStore(StringPiece s);
97 
98   std::unordered_map<StringPiece, Device*, StringPieceHasher> device_map_;
99   core::Arena name_backing_store_;  // Storage for keys in device_map_
100   std::unordered_map<string, int> device_type_counts_;
101   Device* cpu_device_;
102 
103   TF_DISALLOW_COPY_AND_ASSIGN(StaticDeviceMgr);
104 };
105 
106 // Represents a dynamic set of devices
107 class DynamicDeviceMgr : public DeviceMgr {
108  public:
109   // Constructs an empty DynamicDeviceMgr.
110   DynamicDeviceMgr();
111 
112   ~DynamicDeviceMgr() override;
113 
114   void ListDeviceAttributes(
115       std::vector<DeviceAttributes>* devices) const override;
116   std::vector<Device*> ListDevices() const override;
117   string DebugString() const override;
118   string DeviceMappingString() const override;
119   Status LookupDevice(StringPiece name, Device** device) const override;
120   void ClearContainers(gtl::ArraySlice<string> containers) const override;
121   int NumDeviceType(const string& type) const override;
122   Device* HostCPU() const override;
123 
124   // Add devices to device manager. Returns error for repeated device names.
125   Status AddDevices(std::vector<std::unique_ptr<Device>> devices);
126 
127   // Remove devices from device manager.
128   // Returns error for non-existing devices or if the HostCPU() device is in the
129   // input list. If an error is returned, the device list is not modified.
130   Status RemoveDevices(std::vector<Device*> devices);
131 
132   // Remove devices from device manager by their names. Returns error for
133   // non-existing devices or if the HostCPU() device is given in the input list.
134   // If an error is returned, the device list is not modified.
135   Status RemoveDevicesByName(const std::vector<string>& device_names);
136 
137  private:
138   mutable mutex devices_mu_;
139 
140   std::unordered_map<Device*, std::unique_ptr<Device>> dynamic_devices_
141       GUARDED_BY(devices_mu_);
142 
143   std::unordered_map<string, Device*> device_map_ GUARDED_BY(devices_mu_);
144 
145   std::unordered_map<string, int> device_type_counts_ GUARDED_BY(devices_mu_);
146 
147   mutable Device* cpu_device_ GUARDED_BY(devices_mu_);
148 
149   TF_DISALLOW_COPY_AND_ASSIGN(DynamicDeviceMgr);
150 };
151 }  // namespace tensorflow
152 
153 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_
154