/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/computation_placer.h"

#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

using absl::StrAppend;
using absl::StrCat;

namespace xla {

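// Finds the replica whose row of the (replica, computation) grid contains
// `device_ordinal`. Errors if the ordinal is absent or appears more than
// once. For example, for the 2x2 assignment produced by AssignDevices(2, 2)
// (rows are replicas, columns are computations):
//   {{0, 2},
//    {1, 3}}
// ReplicaIdForDeviceOrdinal(3) returns 1.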
StatusOr<int> DeviceAssignment::ReplicaIdForDeviceOrdinal(
    int device_ordinal) const {
  absl::optional<int> replica_id;
  for (int64 r = 0; r < replica_count(); ++r) {
    for (int64 c = 0; c < computation_count(); ++c) {
      if ((*this)(r, c) == device_ordinal) {
        if (replica_id.has_value()) {
          return InternalError(
              "Device ordinal %d appears twice in DeviceAssignment: %s",
              device_ordinal, ToString());
        }
        replica_id = r;
      }
    }
  }
  if (!replica_id.has_value()) {
    return InternalError(
        "Device ordinal %d doesn't appear in DeviceAssignment %s",
        device_ordinal, ToString());
  }
  return *replica_id;
}

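// Writes this assignment into `proto`: one ComputationDevice message per
// computation, each listing that computation's device ids in replica order.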
Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const {
  proto->set_replica_count(replica_count());
  proto->set_computation_count(computation_count());
  for (int computation = 0; computation < computation_count(); ++computation) {
    DeviceAssignmentProto::ComputationDevice* computation_device =
        proto->add_computation_devices();
    for (int replica = 0; replica < replica_count(); ++replica) {
      computation_device->add_replica_device_ids(
          (*this)(replica, computation));
    }
  }
  return Status::OK();
}

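// Reconstructs a DeviceAssignment from `proto`, validating that the proto's
// replica/computation dimensions are positive and consistent with the
// per-computation device lists.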
/* static */ StatusOr<std::unique_ptr<DeviceAssignment>>
DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) {
  TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count());
  if (proto.replica_count() <= 0 || proto.computation_count() <= 0) {
    return InvalidArgument(
        "Invalid device assignment topology: replica_count=%d, "
        "computation_count=%d",
        proto.replica_count(), proto.computation_count());
  }
  auto assignment = absl::make_unique<DeviceAssignment>(
      proto.replica_count(), proto.computation_count());
  for (int computation = 0; computation < proto.computation_count();
       ++computation) {
    const auto& computation_device = proto.computation_devices(computation);
    TF_RET_CHECK(computation_device.replica_device_ids_size() ==
                 proto.replica_count());
    for (int replica = 0; replica < proto.replica_count(); ++replica) {
      (*assignment)(replica, computation) =
          computation_device.replica_device_ids(replica);
    }
  }
  return std::move(assignment);
}

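// Renders the assignment as human-readable text, e.g. for the default 2x2
// assignment:
//   Computations: 2 Replicas: 2
//   Computation 0: 0 1
//   Computation 1: 2 3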
string DeviceAssignment::ToString() const {
  string output = StrCat("Computations: ", computation_count(),
                         " Replicas: ", replica_count(), "\n");
  for (int computation = 0; computation < computation_count(); ++computation) {
    StrAppend(&output, "Computation ", computation, ": ");
    for (int replica = 0; replica < replica_count(); ++replica) {
      StrAppend(&output, operator()(replica, computation), " ");
    }
    StrAppend(&output, "\n");
  }
  return output;
}

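// Maps (replica, computation) to a device ordinal. The default placement is
// column-major over the logical grid, so a computation's replicas occupy
// consecutive device ids; e.g. DeviceId(1, 2, /*replica_count=*/4,
// /*computation_count=*/3) = 2 * 4 + 1 = 9.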
StatusOr<int> ComputationPlacer::DeviceId(int replica, int computation,
                                          int replica_count,
                                          int computation_count) {
  TF_RET_CHECK(replica < replica_count);
  TF_RET_CHECK(computation < computation_count);

  return computation * replica_count + replica;
}

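// Builds the full replica_count x computation_count assignment by evaluating
// DeviceId at every grid position.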
StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
    int replica_count, int computation_count) {
  DeviceAssignment assignment(replica_count, computation_count);
  for (int replica = 0; replica < replica_count; ++replica) {
    for (int computation = 0; computation < computation_count; ++computation) {
      TF_ASSIGN_OR_RETURN(
          int device_id,
          DeviceId(replica, computation, replica_count, computation_count));
      assignment(replica, computation) = device_id;
    }
  }
  return std::move(assignment);
}

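// Registers `creation_function` as the ComputationPlacer factory for
// `platform_id`. CHECK-fails if a placer is already registered for the
// platform; registration is expected to happen exactly once per platform,
// from a module initializer (see InitModule below).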
/* static */ void ComputationPlacer::RegisterComputationPlacer(
    se::Platform::Id platform_id,
    ComputationPlacerCreationFunction creation_function) {
  tensorflow::mutex_lock lock(
      ComputationPlacer::platform_computation_placer_mutex_);
  auto* computation_placers = GetPlatformComputationPlacers();
  CHECK(computation_placers->find(platform_id) == computation_placers->end());
  (*computation_placers)[platform_id].creation_function = creation_function;
}

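// Looks up the placer registered for `platform`, constructing it on first
// use. The returned pointer is owned by the registry and remains valid for
// the life of the process. A minimal usage sketch, assuming a `platform`
// pointer and the desired counts are in scope:
//
//   TF_ASSIGN_OR_RETURN(ComputationPlacer* placer,
//                       ComputationPlacer::GetForPlatform(platform));
//   TF_ASSIGN_OR_RETURN(DeviceAssignment assignment,
//                       placer->AssignDevices(replica_count,
//                                             computation_count));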
/* static */ StatusOr<ComputationPlacer*> ComputationPlacer::GetForPlatform(
    const se::Platform* platform) {
  tensorflow::mutex_lock lock(
      ComputationPlacer::platform_computation_placer_mutex_);
  auto* computation_placers = GetPlatformComputationPlacers();

  auto it = computation_placers->find(platform->id());
  if (it == computation_placers->end()) {
    return NotFound(
        "could not find registered computation placer for platform %s -- "
        "check target linkage",
        platform->Name());
  }

  if (it->second.placer == nullptr) {
    // Lazily create the computation placer the first time it is needed.
    it->second.placer = (*it->second.creation_function)();
  }

  return it->second.placer.get();
}

/* static */ tensorflow::mutex
    ComputationPlacer::platform_computation_placer_mutex_(
        tensorflow::LINKER_INITIALIZED);

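// Returns the lazily constructed, intentionally leaked map from platform id
// to placer state that backs the registry above.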
/* static */ std::map<se::Platform::Id, ComputationPlacer::State>*
ComputationPlacer::GetPlatformComputationPlacers() {
  static auto* r = new std::map<se::Platform::Id, ComputationPlacer::State>;
  return r;
}

}  // namespace xla

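// Factory handed to RegisterComputationPlacer for each of the platforms
// below; every platform gets the default ComputationPlacer.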
static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() {
  return absl::make_unique<xla::ComputationPlacer>();
}

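// Registers the default placer for the Host, CUDA, and ROCm platforms at
// load time via the module_initialized static below.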
static bool InitModule() {
  xla::ComputationPlacer::RegisterComputationPlacer(
      stream_executor::host::kHostPlatformId, &CreateComputationPlacer);
  xla::ComputationPlacer::RegisterComputationPlacer(
      stream_executor::cuda::kCudaPlatformId, &CreateComputationPlacer);
  xla::ComputationPlacer::RegisterComputationPlacer(
      stream_executor::rocm::kROCmPlatformId, &CreateComputationPlacer);
  return true;
}
static bool module_initialized = InitModule();