/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/computation_placer.h"

#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/host/host_platform_id.h"
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"

using absl::StrAppend;
using absl::StrCat;

namespace xla {

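// Scans the (replica, computation) grid for the given global device id and
// returns its logical coordinates. It is an error for a device to appear
// more than once, or not at all, in the assignment.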
StatusOr<DeviceAssignment::LogicalID> DeviceAssignment::LogicalIdForDevice(
    GlobalDeviceId device_id) const {
  absl::optional<DeviceAssignment::LogicalID> logical_id;
  for (int r = 0; r < replica_count(); ++r) {
    for (int c = 0; c < computation_count(); ++c) {
      if ((*this)(r, c) == device_id.value()) {
        if (logical_id.has_value()) {
          return InternalError(
              "Device %d appears twice in DeviceAssignment: %s",
              device_id.value(), ToString());
        }
        logical_id.emplace(DeviceAssignment::LogicalID{r, c});
      }
    }
  }
  if (logical_id.has_value()) {
    return *logical_id;
  } else {
    return InternalError("Device %d doesn't appear in DeviceAssignment: %s",
                         device_id.value(), ToString());
  }
}

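// Convenience wrapper around LogicalIdForDevice that returns only the
// replica id component of the logical coordinate.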
StatusOr<int> DeviceAssignment::ReplicaIdForDevice(
    GlobalDeviceId device_id) const {
  TF_ASSIGN_OR_RETURN(const LogicalID logical_id,
                      LogicalIdForDevice(device_id));
  return logical_id.replica_id;
}

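// Serializes this assignment into `proto`: one ComputationDevice message per
// computation, each listing that computation's device ids in replica order.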
Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const {
  proto->set_replica_count(replica_count());
  proto->set_computation_count(computation_count());
  for (int computation = 0; computation < computation_count(); ++computation) {
    DeviceAssignmentProto::ComputationDevice* computation_device =
        proto->add_computation_devices();
    for (int replica = 0; replica < replica_count(); ++replica) {
      computation_device->add_replica_device_ids((*this)(replica, computation));
    }
  }
  return Status::OK();
}

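// Reconstructs a DeviceAssignment from its proto form, validating that the
// proto's dimensions are positive and internally consistent.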
/* static */ StatusOr<std::unique_ptr<DeviceAssignment>>
DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) {
  TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count());
  if (proto.replica_count() <= 0 || proto.computation_count() <= 0) {
    return InvalidArgument(
        "Invalid device assignment topology: replica_count=%d, "
        "computation_count=%d",
        proto.replica_count(), proto.computation_count());
  }
  auto assignment = absl::make_unique<DeviceAssignment>(
      proto.replica_count(), proto.computation_count());
  for (int computation = 0; computation < proto.computation_count();
       ++computation) {
    const auto& computation_device = proto.computation_devices(computation);
    TF_RET_CHECK(computation_device.replica_device_ids_size() ==
                 proto.replica_count());
    for (int replica = 0; replica < proto.replica_count(); ++replica) {
      (*assignment)(replica, computation) =
          computation_device.replica_device_ids(replica);
    }
  }
  return std::move(assignment);
}

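// Renders the assignment as human-readable text, one line per computation.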
string DeviceAssignment::ToString() const {
  string output = StrCat("Computations: ", computation_count(),
                         " Replicas: ", replica_count(), "\n");
  for (int computation = 0; computation < computation_count(); ++computation) {
    StrAppend(&output, "Computation ", computation, ": ");
    for (int replica = 0; replica < replica_count(); ++replica) {
      StrAppend(&output, operator()(replica, computation), " ");
    }
    StrAppend(&output, "\n");
  }
  return output;
}

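// Maps a (replica, computation) coordinate to a flat device id using a
// computation-major layout: all replicas of computation 0 come first, then
// all replicas of computation 1, and so on. As an illustrative example, with
// replica_count=2 and computation_count=3 the resulting grid is:
//
//   replica 0: devices 0 2 4
//   replica 1: devices 1 3 5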
StatusOr<int> ComputationPlacer::DeviceId(int replica, int computation,
                                          int replica_count,
                                          int computation_count) {
  TF_RET_CHECK(replica < replica_count);
  TF_RET_CHECK(computation < computation_count);

  return computation * replica_count + replica;
}

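// Builds the default replica_count x computation_count assignment by filling
// every cell with the id produced by DeviceId above.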
StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
    int replica_count, int computation_count) {
  DeviceAssignment assignment(replica_count, computation_count);
  for (int replica = 0; replica < replica_count; ++replica) {
    for (int computation = 0; computation < computation_count; ++computation) {
      TF_ASSIGN_OR_RETURN(
          int device_id,
          DeviceId(replica, computation, replica_count, computation_count));
      assignment(replica, computation) = device_id;
    }
  }
  return std::move(assignment);
}

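// Registers a placer factory for `platform_id`. Registering the same
// platform twice is a fatal error (CHECK failure).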
/* static */ void ComputationPlacer::RegisterComputationPlacer(
    se::Platform::Id platform_id,
    ComputationPlacerCreationFunction creation_function) {
  tensorflow::mutex_lock lock(
      ComputationPlacer::platform_computation_placer_mutex_);
  auto* computation_placers = GetPlatformComputationPlacers();
  CHECK(computation_placers->find(platform_id) == computation_placers->end());
  (*computation_placers)[platform_id].creation_function = creation_function;
}

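// Looks up the placer registered for `platform`, constructing it on first
// use. Returns NotFound if no factory was registered, which usually
// indicates a linkage problem.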
/* static */ StatusOr<ComputationPlacer*> ComputationPlacer::GetForPlatform(
    const se::Platform* platform) {
  tensorflow::mutex_lock lock(
      ComputationPlacer::platform_computation_placer_mutex_);
  auto* computation_placers = GetPlatformComputationPlacers();

  auto it = computation_placers->find(platform->id());
  if (it == computation_placers->end()) {
    return NotFound(
        "could not find registered computation placer for platform %s -- check "
        "target linkage",
        platform->Name());
  }

  if (it->second.placer == nullptr) {
    // Lazily create the computation placer the first time it is needed.
    it->second.placer = (*it->second.creation_function)();
  }

  return it->second.placer.get();
}

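// Global registry state: a linker-initialized mutex guarding a lazily
// created map from platform id to that platform's placer state.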
/* static */ tensorflow::mutex
    ComputationPlacer::platform_computation_placer_mutex_(
        tensorflow::LINKER_INITIALIZED);

/* static */ std::map<se::Platform::Id, ComputationPlacer::State>*
ComputationPlacer::GetPlatformComputationPlacers() {
  static auto* r = new std::map<se::Platform::Id, ComputationPlacer::State>;
  return r;
}

}  // namespace xla

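// Module initializer: registers the default ComputationPlacer for the host,
// CUDA, and ROCm platforms when this translation unit is linked in.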
static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() {
  return absl::make_unique<xla::ComputationPlacer>();
}

static bool InitModule() {
  xla::ComputationPlacer::RegisterComputationPlacer(
      stream_executor::host::kHostPlatformId, &CreateComputationPlacer);
  xla::ComputationPlacer::RegisterComputationPlacer(
      stream_executor::cuda::kCudaPlatformId, &CreateComputationPlacer);
  xla::ComputationPlacer::RegisterComputationPlacer(
      stream_executor::rocm::kROCmPlatformId, &CreateComputationPlacer);
  return true;
}
static bool module_initialized = InitModule();