/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/computation_placer.h"

#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/host/host_platform_id.h"
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"

using absl::StrAppend;
using absl::StrCat;

namespace xla {
StatusOr<DeviceAssignment::LogicalID> DeviceAssignment::LogicalIdForDevice(
    GlobalDeviceId device_id) const {
  absl::optional<DeviceAssignment::LogicalID> logical_id;
  for (int r = 0; r < replica_count(); ++r) {
    for (int c = 0; c < computation_count(); ++c) {
      if ((*this)(r, c) == device_id.value()) {
        if (logical_id.has_value()) {
          return InternalError(
              "Device %d appears twice in DeviceAssignment: %s",
              device_id.value(), ToString());
        }
        logical_id.emplace(DeviceAssignment::LogicalID{r, c});
      }
    }
  }
  if (logical_id.has_value()) {
    return *logical_id;
  } else {
    return InternalError("Device %d doesn't appear in DeviceAssignment: %s",
                         device_id.value(), ToString());
  }
}
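
// Illustrative note (not in the original file): given a 2x2 assignment in
// which device 3 sits at replica 1 of computation 1,
// LogicalIdForDevice(GlobalDeviceId(3)) returns
// LogicalID{/*replica_id=*/1, /*computation_id=*/1}. A device id that is
// absent from the matrix, or that appears more than once, yields an
// InternalError instead.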

StatusOr<int> DeviceAssignment::ReplicaIdForDevice(
    GlobalDeviceId device_id) const {
  TF_ASSIGN_OR_RETURN(const LogicalID logical_id,
                      LogicalIdForDevice(device_id));
  return logical_id.replica_id;
}

Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const {
  proto->set_replica_count(replica_count());
  proto->set_computation_count(computation_count());
  for (int computation = 0; computation < computation_count(); ++computation) {
    DeviceAssignmentProto::ComputationDevice* computation_device =
        proto->add_computation_devices();
    for (int replica = 0; replica < replica_count(); ++replica) {
      computation_device->add_replica_device_ids((*this)(replica, computation));
    }
  }
  return Status::OK();
}

/* static */ StatusOr<std::unique_ptr<DeviceAssignment>>
DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) {
  TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count());
  if (proto.replica_count() <= 0 || proto.computation_count() <= 0) {
    return InvalidArgument(
        "Invalid device assignment topology: replica_count=%d, "
        "computation_count=%d",
        proto.replica_count(), proto.computation_count());
  }
  auto assignment = absl::make_unique<DeviceAssignment>(
      proto.replica_count(), proto.computation_count());
  for (int computation = 0; computation < proto.computation_count();
       ++computation) {
    const auto& computation_device = proto.computation_devices(computation);
    TF_RET_CHECK(computation_device.replica_device_ids_size() ==
                 proto.replica_count());
    for (int replica = 0; replica < proto.replica_count(); ++replica) {
      (*assignment)(replica, computation) =
          computation_device.replica_device_ids(replica);
    }
  }
  return std::move(assignment);
}
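
// Illustrative round-trip sketch (not in the original file; assumes a valid
// `assignment` in a context that can return Status):
//
//   DeviceAssignmentProto proto;
//   TF_RETURN_IF_ERROR(assignment.Serialize(&proto));
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<DeviceAssignment> restored,
//                       DeviceAssignment::Deserialize(proto));
//
// Deserialize rejects protos whose replica/computation counts are
// non-positive or disagree with the sizes of the embedded device id lists.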

string DeviceAssignment::ToString() const {
  string output = StrCat("Computations: ", computation_count(),
                         " Replicas: ", replica_count(), "\n");
  for (int computation = 0; computation < computation_count(); ++computation) {
    StrAppend(&output, "Computation ", computation, ": ");
    for (int replica = 0; replica < replica_count(); ++replica) {
      StrAppend(&output, operator()(replica, computation), " ");
    }
    StrAppend(&output, "\n");
  }
  return output;
}
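
// For example (illustrative, not in the original file), the default
// 2-replica, 2-computation assignment produced by
// ComputationPlacer::AssignDevices below prints as:
//
//   Computations: 2 Replicas: 2
//   Computation 0: 0 1
//   Computation 1: 2 3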

StatusOr<int> ComputationPlacer::DeviceId(int replica, int computation,
                                          int replica_count,
                                          int computation_count) {
  TF_RET_CHECK(replica < replica_count);
  TF_RET_CHECK(computation < computation_count);

  return computation * replica_count + replica;
}
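
// Illustrative sketch (not in the original file): the mapping is
// computation-major, so replicas of the same computation get consecutive
// device ids. With replica_count=2 and computation_count=3:
//
//   DeviceId(/*replica=*/0, /*computation=*/0, 2, 3) == 0
//   DeviceId(/*replica=*/1, /*computation=*/0, 2, 3) == 1
//   DeviceId(/*replica=*/0, /*computation=*/1, 2, 3) == 2
//   DeviceId(/*replica=*/1, /*computation=*/2, 2, 3) == 5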

StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
    int replica_count, int computation_count) {
  DeviceAssignment assignment(replica_count, computation_count);
  for (int replica = 0; replica < replica_count; ++replica) {
    for (int computation = 0; computation < computation_count; ++computation) {
      TF_ASSIGN_OR_RETURN(
          int device_id,
          DeviceId(replica, computation, replica_count, computation_count));
      assignment(replica, computation) = device_id;
    }
  }
  return std::move(assignment);
}
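
// Illustrative usage sketch (not in the original file):
//
//   ComputationPlacer placer;
//   TF_ASSIGN_OR_RETURN(DeviceAssignment assignment,
//                       placer.AssignDevices(/*replica_count=*/2,
//                                            /*computation_count=*/2));
//
// assignment(replica, computation) then holds the device id chosen by
// DeviceId() above.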

/* static */ void ComputationPlacer::RegisterComputationPlacer(
    se::Platform::Id platform_id,
    ComputationPlacerCreationFunction creation_function) {
  tensorflow::mutex_lock lock(
      ComputationPlacer::platform_computation_placer_mutex_);
  auto* computation_placers = GetPlatformComputationPlacers();
  CHECK(computation_placers->find(platform_id) == computation_placers->end());
  (*computation_placers)[platform_id].creation_function = creation_function;
}

/* static */ StatusOr<ComputationPlacer*> ComputationPlacer::GetForPlatform(
    const se::Platform* platform) {
  tensorflow::mutex_lock lock(
      ComputationPlacer::platform_computation_placer_mutex_);
  auto* computation_placers = GetPlatformComputationPlacers();

  auto it = computation_placers->find(platform->id());
  if (it == computation_placers->end()) {
    return NotFound(
        "could not find registered computation placer for platform %s -- "
        "check target linkage",
        platform->Name());
  }

  if (it->second.placer == nullptr) {
    // Lazily create the computation placer the first time it is needed.
    it->second.placer = (*it->second.creation_function)();
  }

  return it->second.placer.get();
}
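
// Illustrative usage sketch (not in the original file; assumes a placer has
// been registered for `platform`, as InitModule below does for the Host,
// CUDA, and ROCm platforms):
//
//   TF_ASSIGN_OR_RETURN(ComputationPlacer* placer,
//                       ComputationPlacer::GetForPlatform(platform));
//   TF_ASSIGN_OR_RETURN(DeviceAssignment assignment,
//                       placer->AssignDevices(replica_count,
//                                             computation_count));
//
// The returned placer is owned by the registry and must not be deleted by
// the caller.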

/* static */ tensorflow::mutex
    ComputationPlacer::platform_computation_placer_mutex_(
        tensorflow::LINKER_INITIALIZED);

/* static */ std::map<se::Platform::Id, ComputationPlacer::State>*
ComputationPlacer::GetPlatformComputationPlacers() {
  static auto* r = new std::map<se::Platform::Id, ComputationPlacer::State>;
  return r;
}

}  // namespace xla

static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() {
  return absl::make_unique<xla::ComputationPlacer>();
}

static bool InitModule() {
  xla::ComputationPlacer::RegisterComputationPlacer(
      stream_executor::host::kHostPlatformId, &CreateComputationPlacer);
  xla::ComputationPlacer::RegisterComputationPlacer(
      stream_executor::cuda::kCudaPlatformId, &CreateComputationPlacer);
  xla::ComputationPlacer::RegisterComputationPlacer(
      stream_executor::rocm::kROCmPlatformId, &CreateComputationPlacer);
  return true;
}
static bool module_initialized = InitModule();