1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/common_runtime/device_set.h"

#include <algorithm>
#include <set>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
26
27 namespace tensorflow {
28
DeviceSet()29 DeviceSet::DeviceSet() {}
30
~DeviceSet()31 DeviceSet::~DeviceSet() {}
32
AddDevice(Device * device)33 void DeviceSet::AddDevice(Device* device) {
34 mutex_lock l(devices_mu_);
35 devices_.push_back(device);
36 prioritized_devices_.clear();
37 prioritized_device_types_.clear();
38 for (const string& name :
39 DeviceNameUtils::GetNamesForDeviceMappings(device->parsed_name())) {
40 device_by_name_.insert({name, device});
41 }
42 }
43
FindMatchingDevices(const DeviceNameUtils::ParsedName & spec,std::vector<Device * > * devices) const44 void DeviceSet::FindMatchingDevices(const DeviceNameUtils::ParsedName& spec,
45 std::vector<Device*>* devices) const {
46 // TODO(jeff): If we are going to repeatedly lookup the set of devices
47 // for the same spec, maybe we should have a cache of some sort
48 devices->clear();
49 for (Device* d : devices_) {
50 if (DeviceNameUtils::IsCompleteSpecification(spec, d->parsed_name())) {
51 devices->push_back(d);
52 }
53 }
54 }
55
FindDeviceByName(const string & name) const56 Device* DeviceSet::FindDeviceByName(const string& name) const {
57 return gtl::FindPtrOrNull(device_by_name_, name);
58 }
59
60 // static
DeviceTypeOrder(const DeviceType & d)61 int DeviceSet::DeviceTypeOrder(const DeviceType& d) {
62 return DeviceFactory::DevicePriority(d.type_string());
63 }
64
DeviceTypeComparator(const DeviceType & a,const DeviceType & b)65 static bool DeviceTypeComparator(const DeviceType& a, const DeviceType& b) {
66 // First sort by prioritized device type (higher is preferred) and
67 // then by device name (lexicographically).
68 auto a_priority = DeviceSet::DeviceTypeOrder(a);
69 auto b_priority = DeviceSet::DeviceTypeOrder(b);
70 if (a_priority != b_priority) {
71 return a_priority > b_priority;
72 }
73
74 return StringPiece(a.type()) < StringPiece(b.type());
75 }
76
PrioritizedDeviceTypeList() const77 std::vector<DeviceType> DeviceSet::PrioritizedDeviceTypeList() const {
78 std::vector<DeviceType> result;
79 std::set<string> seen;
80 for (Device* d : devices_) {
81 const auto& t = d->device_type();
82 if (seen.insert(t).second) {
83 result.emplace_back(t);
84 }
85 }
86 std::sort(result.begin(), result.end(), DeviceTypeComparator);
87 return result;
88 }
89
SortPrioritizedDeviceTypeVector(PrioritizedDeviceTypeVector * vector)90 void DeviceSet::SortPrioritizedDeviceTypeVector(
91 PrioritizedDeviceTypeVector* vector) {
92 if (vector == nullptr) return;
93
94 auto device_sort = [](const PrioritizedDeviceTypeVector::value_type& a,
95 const PrioritizedDeviceTypeVector::value_type& b) {
96 // First look at set priorities.
97 if (a.second != b.second) {
98 return a.second > b.second;
99 }
100 // Then fallback to default priorities.
101 return DeviceTypeComparator(a.first, b.first);
102 };
103
104 std::sort(vector->begin(), vector->end(), device_sort);
105 }
106
SortPrioritizedDeviceVector(PrioritizedDeviceVector * vector)107 void DeviceSet::SortPrioritizedDeviceVector(PrioritizedDeviceVector* vector) {
108 auto device_sort = [](const std::pair<Device*, int32>& a,
109 const std::pair<Device*, int32>& b) {
110 if (a.second != b.second) {
111 return a.second > b.second;
112 }
113
114 const string& a_type_name = a.first->device_type();
115 const string& b_type_name = b.first->device_type();
116 if (a_type_name != b_type_name) {
117 auto a_priority = DeviceFactory::DevicePriority(a_type_name);
118 auto b_priority = DeviceFactory::DevicePriority(b_type_name);
119 if (a_priority != b_priority) {
120 return a_priority > b_priority;
121 }
122 }
123
124 if (a.first->IsLocal() != b.first->IsLocal()) {
125 return a.first->IsLocal();
126 }
127
128 return StringPiece(a.first->name()) < StringPiece(b.first->name());
129 };
130 std::sort(vector->begin(), vector->end(), device_sort);
131 }
132
133 namespace {
134
UpdatePrioritizedVectors(const std::vector<Device * > & devices,PrioritizedDeviceVector * prioritized_devices,PrioritizedDeviceTypeVector * prioritized_device_types)135 void UpdatePrioritizedVectors(
136 const std::vector<Device*>& devices,
137 PrioritizedDeviceVector* prioritized_devices,
138 PrioritizedDeviceTypeVector* prioritized_device_types) {
139 if (prioritized_devices->size() != devices.size()) {
140 for (Device* d : devices) {
141 prioritized_devices->emplace_back(
142 d, DeviceSet::DeviceTypeOrder(DeviceType(d->device_type())));
143 }
144 DeviceSet::SortPrioritizedDeviceVector(prioritized_devices);
145 }
146
147 if (prioritized_device_types != nullptr &&
148 prioritized_device_types->size() != devices.size()) {
149 std::set<DeviceType> seen;
150 for (const std::pair<Device*, int32>& p : *prioritized_devices) {
151 DeviceType t(p.first->device_type());
152 if (seen.insert(t).second) {
153 prioritized_device_types->emplace_back(t, p.second);
154 }
155 }
156 }
157 }
158
159 } // namespace
160
prioritized_devices() const161 const PrioritizedDeviceVector& DeviceSet::prioritized_devices() const {
162 mutex_lock l(devices_mu_);
163 UpdatePrioritizedVectors(devices_, &prioritized_devices_,
164 /* prioritized_device_types */ nullptr);
165 return prioritized_devices_;
166 }
167
prioritized_device_types() const168 const PrioritizedDeviceTypeVector& DeviceSet::prioritized_device_types() const {
169 mutex_lock l(devices_mu_);
170 UpdatePrioritizedVectors(devices_, &prioritized_devices_,
171 &prioritized_device_types_);
172 return prioritized_device_types_;
173 }
174
175 } // namespace tensorflow
176