• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_SET_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_SET_H_
18 
19 #include <memory>
20 #include <unordered_map>
21 #include <vector>
22 
23 #include "tensorflow/core/common_runtime/device.h"
24 #include "tensorflow/core/platform/macros.h"
25 #include "tensorflow/core/platform/types.h"
26 #include "tensorflow/core/util/device_name_utils.h"
27 
28 namespace tensorflow {
29 
30 typedef std::vector<std::pair<Device*, int32>> PrioritizedDeviceVector;
31 
32 // DeviceSet is a container class for managing the various types of
33 // devices used by a model.
34 class DeviceSet {
35  public:
36   DeviceSet();
37   ~DeviceSet();
38 
39   // Does not take ownership of 'device'.
40   void AddDevice(Device* device) TF_LOCKS_EXCLUDED(devices_mu_);
41 
42   // Set the device designated as the "client".  This device
43   // must also be registered via AddDevice().
set_client_device(Device * device)44   void set_client_device(Device* device) {
45     DCHECK(client_device_ == nullptr);
46     client_device_ = device;
47   }
48 
49   // Returns a pointer to the device designated as the "client".
client_device()50   Device* client_device() const { return client_device_; }
51 
52   // Return the list of devices in this set.
devices()53   const std::vector<Device*>& devices() const { return devices_; }
54 
55   // Given a DeviceNameUtils::ParsedName (which may have some
56   // wildcards for different components), fills "*devices" with all
57   // devices in "*this" that match "spec".
58   void FindMatchingDevices(const DeviceNameUtils::ParsedName& spec,
59                            std::vector<Device*>* devices) const;
60 
61   // Finds the device with the given "fullname". Returns nullptr if
62   // not found.
63   Device* FindDeviceByName(const string& fullname) const;
64 
65   // Return the list of unique device types in this set, ordered
66   // with more preferable devices earlier.
67   std::vector<DeviceType> PrioritizedDeviceTypeList() const;
68 
69   // Return the prioritized list of devices in this set.
70   // Devices are prioritized first by `DeviceTypeOrder`, then by name.
71   const PrioritizedDeviceVector& prioritized_devices() const
72       TF_LOCKS_EXCLUDED(devices_mu_);
73 
74   // Return the prioritized list of unique device types in this set.
75   //
76   // The list will be ordered by decreasing priority. The priorities (the second
77   // element in the list's `std::pair<DeviceType, int32>`) will be initialized
78   // to the value of `DeviceTypeOrder` for the device types.
79   const PrioritizedDeviceTypeVector& prioritized_device_types() const
80       TF_LOCKS_EXCLUDED(devices_mu_);
81 
82   // An order to sort by device types according to system-determined
83   // priority.
84   //
85   // Higher result implies higher priority.
86   static int DeviceTypeOrder(const DeviceType& d);
87 
88   // Sorts a PrioritizedDeviceVector according to devices and explicit
89   // priorities.
90   //
91   // After a call to this function, the argument vector will be sorted by
92   // explicit priority (the second element in the `std::pair<DeviceType,
93   // int32>`), then by `DeviceTypeOrder` of the device type, then by device
94   // locality, and lastly by device name.
95   static void SortPrioritizedDeviceVector(PrioritizedDeviceVector* vector);
96 
97   // Sorts a PrioritizedDeviceTypeVector according to types and explicit
98   // priorities.
99   //
100   // After a call to this function, the argument vector will be sorted by
101   // explicit priority (the second element in the `std::pair<DeviceType,
102   // int32>`), then by `DeviceTypeOrder` of the device type.
103   static void SortPrioritizedDeviceTypeVector(
104       PrioritizedDeviceTypeVector* vector);
105 
106  private:
107   mutable mutex devices_mu_;
108 
109   // Not owned.
110   std::vector<Device*> devices_;
111 
112   // Cached prioritized vector, created on-the-fly when
113   // prioritized_devices() is called.
114   mutable PrioritizedDeviceVector prioritized_devices_
115       TF_GUARDED_BY(devices_mu_);
116 
117   // Cached prioritized vector, created on-the-fly when
118   // prioritized_device_types() is called.
119   mutable PrioritizedDeviceTypeVector prioritized_device_types_
120       TF_GUARDED_BY(devices_mu_);
121 
122   // Fullname -> device* for device in devices_.
123   std::unordered_map<string, Device*> device_by_name_;
124 
125   // client_device_ points to an element of devices_ that we consider
126   // to be the client device (in this local process).
127   Device* client_device_ = nullptr;
128 
129   TF_DISALLOW_COPY_AND_ASSIGN(DeviceSet);
130 };
131 
132 }  // namespace tensorflow
133 
134 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_SET_H_
135