/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/computation_placer.h"

#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/host/host_platform_id.h"
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"

using absl::StrAppend;
using absl::StrCat;

namespace xla {

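// A DeviceAssignment is a (replica_count x computation_count) matrix whose
// (replica, computation) entry holds the global id of the device that runs
// that replica of that computation. ComputationPlacer builds such assignments
// and keeps a per-platform registry of placer implementations (see the bottom
// of this file).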
StatusOr<std::pair<int, int>> DeviceAssignment::LogicalIdsForDevice(
    GlobalDeviceId device_id) const {
  absl::optional<std::pair<int, int>> logical_ids;
  for (int r = 0; r < replica_count(); ++r) {
    for (int c = 0; c < computation_count(); ++c) {
      if ((*this)(r, c) == device_id.value()) {
        if (logical_ids.has_value()) {
          return InternalError(
              "Device %d appears twice in DeviceAssignment: %s",
              device_id.value(), ToString());
        }
        logical_ids.emplace(r, c);
      }
    }
  }
  if (logical_ids.has_value()) {
    return *logical_ids;
  } else {
    return InternalError("Device %d doesn't appear in DeviceAssignment: %s",
                         device_id.value(), ToString());
  }
}
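
// Example (illustrative only): given the 2x2 assignment [[0, 2], [1, 3]]
// (rows are replicas, columns are computations),
// LogicalIdsForDevice(GlobalDeviceId(3)) returns the pair (1, 1), while a
// device id that is absent, or one that appears more than once, yields an
// InternalError.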

StatusOr<int> DeviceAssignment::ReplicaIdForDevice(
    GlobalDeviceId device_id) const {
  TF_ASSIGN_OR_RETURN(auto logical_ids, LogicalIdsForDevice(device_id));
  return logical_ids.first;
}

Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const {
  proto->set_replica_count(replica_count());
  proto->set_computation_count(computation_count());
  for (int computation = 0; computation < computation_count(); ++computation) {
    DeviceAssignmentProto::ComputationDevice* computation_device =
        proto->add_computation_devices();
    for (int replica = 0; replica < replica_count(); ++replica) {
      computation_device->add_replica_device_ids((*this)(replica, computation));
    }
  }
  return Status::OK();
}

/* static */ StatusOr<std::unique_ptr<DeviceAssignment>>
DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) {
  TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count());
  if (proto.replica_count() <= 0 || proto.computation_count() <= 0) {
    return InvalidArgument(
        "Invalid device assignment topology: replica_count=%d, "
        "computation_count=%d",
        proto.replica_count(), proto.computation_count());
  }
  auto assignment = absl::make_unique<DeviceAssignment>(
      proto.replica_count(), proto.computation_count());
  for (int computation = 0; computation < proto.computation_count();
       ++computation) {
    const auto& computation_device = proto.computation_devices(computation);
    TF_RET_CHECK(computation_device.replica_device_ids_size() ==
                 proto.replica_count());
    for (int replica = 0; replica < proto.replica_count(); ++replica) {
      (*assignment)(replica, computation) =
          computation_device.replica_device_ids(replica);
    }
  }
  return std::move(assignment);
}
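
// Round-trip sketch (hypothetical caller; assumes an existing DeviceAssignment
// `da` in a function that returns Status or StatusOr):
//
//   DeviceAssignmentProto proto;
//   TF_RETURN_IF_ERROR(da.Serialize(&proto));
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<DeviceAssignment> restored,
//                       DeviceAssignment::Deserialize(proto));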

string DeviceAssignment::ToString() const {
  string output = StrCat("Computations: ", computation_count(),
                         " Replicas: ", replica_count(), "\n");
  for (int computation = 0; computation < computation_count(); ++computation) {
    StrAppend(&output, "Computation ", computation, ": ");
    for (int replica = 0; replica < replica_count(); ++replica) {
      StrAppend(&output, operator()(replica, computation), " ");
    }
    StrAppend(&output, "\n");
  }
  return output;
}
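
// With the default placement from AssignDevices below, two replicas and two
// computations print as:
//
//   Computations: 2 Replicas: 2
//   Computation 0: 0 1
//   Computation 1: 2 3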

StatusOr<int> ComputationPlacer::DeviceId(int replica, int computation,
                                          int replica_count,
                                          int computation_count) {
  TF_RET_CHECK(replica < replica_count);
  TF_RET_CHECK(computation < computation_count);

  return computation * replica_count + replica;
}
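
// The formula above lays devices out computation-major: with replica_count=2
// and computation_count=3, (replica=1, computation=2) maps to device
// 2 * 2 + 1 = 5, and the six (replica, computation) pairs cover devices 0..5
// exactly once.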

StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
    int replica_count, int computation_count) {
  DeviceAssignment assignment(replica_count, computation_count);
  for (int replica = 0; replica < replica_count; ++replica) {
    for (int computation = 0; computation < computation_count; ++computation) {
      TF_ASSIGN_OR_RETURN(
          int device_id,
          DeviceId(replica, computation, replica_count, computation_count));
      assignment(replica, computation) = device_id;
    }
  }
  return std::move(assignment);
}
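
// Usage sketch (hypothetical caller; assumes a `placer` obtained from
// ComputationPlacer::GetForPlatform):
//
//   TF_ASSIGN_OR_RETURN(DeviceAssignment assignment,
//                       placer->AssignDevices(/*replica_count=*/2,
//                                             /*computation_count=*/2));
//   int device = assignment(/*replica=*/1, /*computation=*/1);  // == 3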

/* static */ void ComputationPlacer::RegisterComputationPlacer(
    se::Platform::Id platform_id,
    ComputationPlacerCreationFunction creation_function) {
  tensorflow::mutex_lock lock(
      ComputationPlacer::platform_computation_placer_mutex_);
  auto* computation_placers = GetPlatformComputationPlacers();
  CHECK(computation_placers->find(platform_id) == computation_placers->end());
  (*computation_placers)[platform_id].creation_function = creation_function;
}

/* static */ StatusOr<ComputationPlacer*> ComputationPlacer::GetForPlatform(
    const se::Platform* platform) {
  tensorflow::mutex_lock lock(
      ComputationPlacer::platform_computation_placer_mutex_);
  auto* computation_placers = GetPlatformComputationPlacers();

  auto it = computation_placers->find(platform->id());
  if (it == computation_placers->end()) {
    return NotFound(
        "could not find registered computation placer for platform %s -- "
        "check target linkage",
        platform->Name());
  }

  if (it->second.placer == nullptr) {
    // Lazily create the computation placer the first time it is needed.
    it->second.placer = (*it->second.creation_function)();
  }

  return it->second.placer.get();
}

/* static */ tensorflow::mutex
    ComputationPlacer::platform_computation_placer_mutex_(
        tensorflow::LINKER_INITIALIZED);

/* static */ std::map<se::Platform::Id, ComputationPlacer::State>*
ComputationPlacer::GetPlatformComputationPlacers() {
  static auto* r = new std::map<se::Platform::Id, ComputationPlacer::State>;
  return r;
}

}  // namespace xla

static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() {
  return absl::make_unique<xla::ComputationPlacer>();
}
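
// The registrations below run from a static initializer, so linking this file
// into a binary is enough to make the default ComputationPlacer available for
// the host, CUDA, and ROCm platforms; GetForPlatform then instantiates it
// lazily via CreateComputationPlacer.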

static bool InitModule() {
  xla::ComputationPlacer::RegisterComputationPlacer(
      stream_executor::host::kHostPlatformId, &CreateComputationPlacer);
  xla::ComputationPlacer::RegisterComputationPlacer(
      stream_executor::cuda::kCudaPlatformId, &CreateComputationPlacer);
  xla::ComputationPlacer::RegisterComputationPlacer(
      stream_executor::rocm::kROCmPlatformId, &CreateComputationPlacer);
  return true;
}
static bool module_initialized = InitModule();