1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/device_set.h"
17
18 #include <set>
19 #include <utility>
20 #include <vector>
21
22 #include "tensorflow/core/common_runtime/device.h"
23 #include "tensorflow/core/common_runtime/device_factory.h"
24 #include "tensorflow/core/lib/core/stringpiece.h"
25 #include "tensorflow/core/lib/gtl/map_util.h"
26
27 namespace tensorflow {
28
DeviceSet()29 DeviceSet::DeviceSet() {}
30
~DeviceSet()31 DeviceSet::~DeviceSet() {}
32
AddDevice(Device * device)33 void DeviceSet::AddDevice(Device* device) {
34 devices_.push_back(device);
35 for (const string& name :
36 DeviceNameUtils::GetNamesForDeviceMappings(device->parsed_name())) {
37 device_by_name_.insert({name, device});
38 }
39 }
40
FindMatchingDevices(const DeviceNameUtils::ParsedName & spec,std::vector<Device * > * devices) const41 void DeviceSet::FindMatchingDevices(const DeviceNameUtils::ParsedName& spec,
42 std::vector<Device*>* devices) const {
43 // TODO(jeff): If we are going to repeatedly lookup the set of devices
44 // for the same spec, maybe we should have a cache of some sort
45 devices->clear();
46 for (Device* d : devices_) {
47 if (DeviceNameUtils::IsCompleteSpecification(spec, d->parsed_name())) {
48 devices->push_back(d);
49 }
50 }
51 }
52
FindDeviceByName(const string & name) const53 Device* DeviceSet::FindDeviceByName(const string& name) const {
54 return gtl::FindPtrOrNull(device_by_name_, name);
55 }
56
57 // static
DeviceTypeOrder(const DeviceType & d)58 int DeviceSet::DeviceTypeOrder(const DeviceType& d) {
59 return DeviceFactory::DevicePriority(d.type_string());
60 }
61
DeviceTypeComparator(const DeviceType & a,const DeviceType & b)62 static bool DeviceTypeComparator(const DeviceType& a, const DeviceType& b) {
63 // First sort by prioritized device type (higher is preferred) and
64 // then by device name (lexicographically).
65 auto a_priority = DeviceSet::DeviceTypeOrder(a);
66 auto b_priority = DeviceSet::DeviceTypeOrder(b);
67 if (a_priority != b_priority) {
68 return a_priority > b_priority;
69 }
70
71 return StringPiece(a.type()) < StringPiece(b.type());
72 }
73
PrioritizedDeviceTypeList() const74 std::vector<DeviceType> DeviceSet::PrioritizedDeviceTypeList() const {
75 std::vector<DeviceType> result;
76 std::set<string> seen;
77 for (Device* d : devices_) {
78 const auto& t = d->device_type();
79 if (seen.insert(t).second) {
80 result.emplace_back(t);
81 }
82 }
83 std::sort(result.begin(), result.end(), DeviceTypeComparator);
84 return result;
85 }
86
87 } // namespace tensorflow
88