• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/computation_placer.h"
17 
18 #include <string>
19 #include <utility>
20 #include <vector>
21 
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/types/optional.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/status.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/statusor.h"
30 #include "tensorflow/compiler/xla/types.h"
31 #include "tensorflow/compiler/xla/util.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/lib/core/status.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
36 
37 using absl::StrAppend;
38 using absl::StrCat;
39 
40 namespace xla {
41 
ReplicaIdForDeviceOrdinal(int device_ordinal) const42 StatusOr<int> DeviceAssignment::ReplicaIdForDeviceOrdinal(
43     int device_ordinal) const {
44   absl::optional<int> replica_id;
45   for (int64 r = 0; r < replica_count(); ++r) {
46     for (int64 c = 0; c < computation_count(); ++c) {
47       if ((*this)(r, c) == device_ordinal) {
48         if (replica_id.has_value()) {
49           return InternalError(
50               "Device ordinal %d appears twice in DeviceAssignment? %s",
51               device_ordinal, ToString());
52         }
53         replica_id = r;
54       }
55     }
56   }
57   if (!replica_id.has_value()) {
58     return InternalError(
59         "Device ordinal %d doesn't appear in DeviceAssignment %s",
60         device_ordinal, ToString());
61   }
62   return *replica_id;
63 }
64 
Serialize(DeviceAssignmentProto * proto) const65 Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const {
66   proto->set_replica_count(replica_count());
67   proto->set_computation_count(computation_count());
68   for (int computation = 0; computation < computation_count(); ++computation) {
69     DeviceAssignmentProto::ComputationDevice* computation_device =
70         proto->add_computation_devices();
71     for (int replica = 0; replica < replica_count(); ++replica) {
72       computation_device->add_replica_device_ids((*this)(replica, computation));
73     }
74   }
75   return Status::OK();
76 }
77 
78 /* static */ StatusOr<std::unique_ptr<DeviceAssignment>>
Deserialize(const DeviceAssignmentProto & proto)79 DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) {
80   TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count());
81   if (proto.replica_count() <= 0 || proto.computation_count() <= 0) {
82     return InvalidArgument(
83         "Invalid device assignment topology: replica_count=%d, "
84         "computation_count=%d",
85         proto.replica_count(), proto.computation_count());
86   }
87   auto assignment = absl::make_unique<DeviceAssignment>(
88       proto.replica_count(), proto.computation_count());
89   for (int computation = 0; computation < proto.computation_count();
90        ++computation) {
91     const auto& computation_device = proto.computation_devices(computation);
92     TF_RET_CHECK(computation_device.replica_device_ids_size() ==
93                  proto.replica_count());
94     for (int replica = 0; replica < proto.replica_count(); ++replica) {
95       (*assignment)(replica, computation) =
96           computation_device.replica_device_ids(replica);
97     }
98   }
99   return std::move(assignment);
100 }
101 
ToString() const102 string DeviceAssignment::ToString() const {
103   string output = StrCat("Computations: ", computation_count(),
104                          " Replicas: ", replica_count(), "\n");
105   for (int computation = 0; computation < computation_count(); ++computation) {
106     StrAppend(&output, "Computation ", computation, ": ");
107     for (int replica = 0; replica < replica_count(); ++replica) {
108       StrAppend(&output, operator()(replica, computation), " ");
109     }
110     StrAppend(&output, "\n");
111   }
112   return output;
113 }
114 
DeviceId(int replica,int computation,int replica_count,int computation_count)115 StatusOr<int> ComputationPlacer::DeviceId(int replica, int computation,
116                                           int replica_count,
117                                           int computation_count) {
118   TF_RET_CHECK(replica < replica_count);
119   TF_RET_CHECK(computation < computation_count);
120 
121   return computation * replica_count + replica;
122 }
123 
AssignDevices(int replica_count,int computation_count)124 StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
125     int replica_count, int computation_count) {
126   DeviceAssignment assignment(replica_count, computation_count);
127   for (int replica = 0; replica < replica_count; ++replica) {
128     for (int computation = 0; computation < computation_count; ++computation) {
129       TF_ASSIGN_OR_RETURN(
130           int device_id,
131           DeviceId(replica, computation, replica_count, computation_count));
132       assignment(replica, computation) = device_id;
133     }
134   }
135   return std::move(assignment);
136 }
137 
RegisterComputationPlacer(se::Platform::Id platform_id,ComputationPlacerCreationFunction creation_function)138 /* static */ void ComputationPlacer::RegisterComputationPlacer(
139     se::Platform::Id platform_id,
140     ComputationPlacerCreationFunction creation_function) {
141   tensorflow::mutex_lock lock(
142       ComputationPlacer::platform_computation_placer_mutex_);
143   auto* computation_placers = GetPlatformComputationPlacers();
144   CHECK(computation_placers->find(platform_id) == computation_placers->end());
145   (*computation_placers)[platform_id].creation_function = creation_function;
146 }
147 
GetForPlatform(const se::Platform * platform)148 /* static */ StatusOr<ComputationPlacer*> ComputationPlacer::GetForPlatform(
149     const se::Platform* platform) {
150   tensorflow::mutex_lock lock(
151       ComputationPlacer::platform_computation_placer_mutex_);
152   auto* computation_placers = GetPlatformComputationPlacers();
153 
154   auto it = computation_placers->find(platform->id());
155   if (it == computation_placers->end()) {
156     return NotFound(
157         "could not find registered computation placer for platform %s -- check "
158         "target linkage",
159         platform->Name());
160   }
161 
162   if (it->second.placer == nullptr) {
163     // Lazily create the computation placer the first time it is needed.
164     it->second.placer = (*it->second.creation_function)();
165   }
166 
167   return it->second.placer.get();
168 }
169 
170 /* static */ tensorflow::mutex
171     ComputationPlacer::platform_computation_placer_mutex_(
172         tensorflow::LINKER_INITIALIZED);
173 
174 /* static */ std::map<se::Platform::Id, ComputationPlacer::State>*
GetPlatformComputationPlacers()175 ComputationPlacer::GetPlatformComputationPlacers() {
176   static auto* r = new std::map<se::Platform::Id, ComputationPlacer::State>;
177   return r;
178 }
179 
180 }  // namespace xla
181 
CreateComputationPlacer()182 static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() {
183   return absl::make_unique<xla::ComputationPlacer>();
184 }
185 
InitModule()186 static bool InitModule() {
187   xla::ComputationPlacer::RegisterComputationPlacer(
188       stream_executor::host::kHostPlatformId, &CreateComputationPlacer);
189   xla::ComputationPlacer::RegisterComputationPlacer(
190       stream_executor::cuda::kCudaPlatformId, &CreateComputationPlacer);
191   xla::ComputationPlacer::RegisterComputationPlacer(
192       stream_executor::rocm::kROCmPlatformId, &CreateComputationPlacer);
193   return true;
194 }
195 static bool module_initialized = InitModule();
196