/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/computation_placer.h"

#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

using absl::StrAppend;
using absl::StrCat;

namespace xla {

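// Serializes this device assignment into `proto`, preserving the
// (replica, computation) -> device id mapping.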
Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const {
  proto->set_replica_count(replica_count());
  proto->set_computation_count(computation_count());
  for (int computation = 0; computation < computation_count(); ++computation) {
    DeviceAssignmentProto::ComputationDevice* computation_device =
        proto->add_computation_devices();
    for (int replica = 0; replica < replica_count(); ++replica) {
      computation_device->add_replica_device_ids(
          (*this)(replica, computation));
    }
  }
  return Status::OK();
}

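// Deserializes a DeviceAssignment from `proto`, validating that the proto
// describes a non-empty topology and that every computation lists exactly
// replica_count() device ids.
//
// Rough round-trip sketch (illustrative only; assumes a caller that returns
// Status, variable names are not part of the API):
//
//   DeviceAssignment assignment(/*replica_count=*/2, /*computation_count=*/1);
//   DeviceAssignmentProto proto;
//   TF_RETURN_IF_ERROR(assignment.Serialize(&proto));
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<DeviceAssignment> restored,
//                       DeviceAssignment::Deserialize(proto));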
/* static */ StatusOr<std::unique_ptr<DeviceAssignment>>
DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) {
  TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count());
  if (proto.replica_count() <= 0 || proto.computation_count() <= 0) {
    return InvalidArgument(
        "Invalid device assignment topology: replica_count=%d, "
        "computation_count=%d",
        proto.replica_count(), proto.computation_count());
  }
  auto assignment = absl::make_unique<DeviceAssignment>(
      proto.replica_count(), proto.computation_count());
  for (int computation = 0; computation < proto.computation_count();
       ++computation) {
    const auto& computation_device = proto.computation_devices(computation);
    TF_RET_CHECK(computation_device.replica_device_ids_size() ==
                 proto.replica_count());
    for (int replica = 0; replica < proto.replica_count(); ++replica) {
      (*assignment)(replica, computation) =
          computation_device.replica_device_ids(replica);
    }
  }
  return std::move(assignment);
}

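// Returns a human-readable dump of the assignment: one line per computation,
// listing the device id used by each of its replicas.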
string DeviceAssignment::ToString() const {
  string output = StrCat("Computations: ", computation_count(),
                         " Replicas: ", replica_count(), "\n");
  for (int computation = 0; computation < computation_count(); ++computation) {
    StrAppend(&output, "Computation ", computation, ": ");
    for (int replica = 0; replica < replica_count(); ++replica) {
      StrAppend(&output, operator()(replica, computation), " ");
    }
    StrAppend(&output, "\n");
  }
  return output;
}

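// Maps (replica, computation) to a flat device id, laying out all replicas of
// a computation contiguously. For example, with replica_count=2 and
// computation_count=3, (replica=1, computation=2) maps to 2 * 2 + 1 = 5.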
StatusOr<int> ComputationPlacer::DeviceId(int replica, int computation,
                                          int replica_count,
                                          int computation_count) {
  TF_RET_CHECK(replica < replica_count);
  TF_RET_CHECK(computation < computation_count);

  return computation * replica_count + replica;
}

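// Builds the default assignment: computation c occupies the contiguous device
// id range [c * replica_count, (c + 1) * replica_count).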
StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
    int replica_count, int computation_count) {
  DeviceAssignment assignment(replica_count, computation_count);
  for (int replica = 0; replica < replica_count; ++replica) {
    for (int computation = 0; computation < computation_count; ++computation) {
      TF_ASSIGN_OR_RETURN(
          int device_id,
          DeviceId(replica, computation, replica_count, computation_count));
      assignment(replica, computation) = device_id;
    }
  }
  return std::move(assignment);
}

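// Records `creation_function` for `platform_id` in the global registry.
// Registering the same platform twice is a fatal error (CHECK failure).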
/* static */ void ComputationPlacer::RegisterComputationPlacer(
    se::Platform::Id platform_id,
    ComputationPlacerCreationFunction creation_function) {
  tensorflow::mutex_lock lock(
      ComputationPlacer::platform_computation_placer_mutex_);
  auto* computation_placers = GetPlatformComputationPlacers();
  CHECK(computation_placers->find(platform_id) == computation_placers->end());
  (*computation_placers)[platform_id].creation_function = creation_function;
}

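// Looks up the placer registered for `platform`, constructing it on first use.
// Returns NotFound if no placer was registered for the platform.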
/* static */ StatusOr<ComputationPlacer*> ComputationPlacer::GetForPlatform(
    const se::Platform* platform) {
  tensorflow::mutex_lock lock(
      ComputationPlacer::platform_computation_placer_mutex_);
  auto* computation_placers = GetPlatformComputationPlacers();

  auto it = computation_placers->find(platform->id());
  if (it == computation_placers->end()) {
    return NotFound(
        "could not find registered computation placer for platform %s -- "
        "check target linkage",
        platform->Name());
  }

  if (it->second.placer == nullptr) {
    // Lazily create the computation placer the first time it is needed.
    it->second.placer = (*it->second.creation_function)();
  }

  return it->second.placer.get();
}

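// Guards the platform id -> ComputationPlacer::State registry below.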
/* static */ tensorflow::mutex
    ComputationPlacer::platform_computation_placer_mutex_(
        tensorflow::LINKER_INITIALIZED);

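// Returns the lazily constructed singleton registry mapping platform ids to
// their ComputationPlacer state.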
/* static */ std::map<se::Platform::Id, ComputationPlacer::State>*
ComputationPlacer::GetPlatformComputationPlacers() {
  static auto* r = new std::map<se::Platform::Id, ComputationPlacer::State>;
  return r;
}

}  // namespace xla

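// Factory used below to register the default ComputationPlacer.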
static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() {
  return absl::make_unique<xla::ComputationPlacer>();
}

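// Registers the default ComputationPlacer for the host and CUDA platforms at
// static initialization time, via the module_initialized object below.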
static bool InitModule() {
  xla::ComputationPlacer::RegisterComputationPlacer(
      stream_executor::host::kHostPlatformId, &CreateComputationPlacer);
  xla::ComputationPlacer::RegisterComputationPlacer(
      stream_executor::cuda::kCudaPlatformId, &CreateComputationPlacer);
  return true;
}
static bool module_initialized = InitModule();