• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/computation_placer.h"
17 
18 #include <string>
19 #include <utility>
20 #include <vector>
21 
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/status.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/compiler/xla/statusor.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/compiler/xla/util.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/status.h"
33 #include "tensorflow/core/platform/logging.h"
34 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
35 
36 using absl::StrAppend;
37 using absl::StrCat;
38 
39 namespace xla {
40 
Serialize(DeviceAssignmentProto * proto) const41 Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const {
42   proto->set_replica_count(replica_count());
43   proto->set_computation_count(computation_count());
44   for (int computation = 0; computation < computation_count(); ++computation) {
45     DeviceAssignmentProto::ComputationDevice* computation_device =
46         proto->add_computation_devices();
47     for (int replica = 0; replica < replica_count(); ++replica) {
48       computation_device->add_replica_device_ids((*this)(replica, computation));
49     }
50   }
51   return Status::OK();
52 }
53 
54 /* static */ StatusOr<std::unique_ptr<DeviceAssignment>>
Deserialize(const DeviceAssignmentProto & proto)55 DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) {
56   TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count());
57   if (proto.replica_count() <= 0 || proto.computation_count() <= 0) {
58     return InvalidArgument(
59         "Invalid device assignment topology: replica_count=%d, "
60         "computation_count=%d",
61         proto.replica_count(), proto.computation_count());
62   }
63   auto assignment = absl::make_unique<DeviceAssignment>(
64       proto.replica_count(), proto.computation_count());
65   for (int computation = 0; computation < proto.computation_count();
66        ++computation) {
67     const auto& computation_device = proto.computation_devices(computation);
68     TF_RET_CHECK(computation_device.replica_device_ids_size() ==
69                  proto.replica_count());
70     for (int replica = 0; replica < proto.replica_count(); ++replica) {
71       (*assignment)(replica, computation) =
72           computation_device.replica_device_ids(replica);
73     }
74   }
75   return std::move(assignment);
76 }
77 
ToString() const78 string DeviceAssignment::ToString() const {
79   string output = StrCat("Computations: ", computation_count(),
80                          " Replicas: ", replica_count(), "\n");
81   for (int computation = 0; computation < computation_count(); ++computation) {
82     StrAppend(&output, "Computation ", computation, ": ");
83     for (int replica = 0; replica < replica_count(); ++replica) {
84       StrAppend(&output, operator()(replica, computation), " ");
85     }
86     StrAppend(&output, "\n");
87   }
88   return output;
89 }
90 
DeviceId(int replica,int computation,int replica_count,int computation_count)91 StatusOr<int> ComputationPlacer::DeviceId(int replica, int computation,
92                                           int replica_count,
93                                           int computation_count) {
94   TF_RET_CHECK(replica < replica_count);
95   TF_RET_CHECK(computation < computation_count);
96 
97   return computation * replica_count + replica;
98 }
99 
AssignDevices(int replica_count,int computation_count)100 StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
101     int replica_count, int computation_count) {
102   DeviceAssignment assignment(replica_count, computation_count);
103   for (int replica = 0; replica < replica_count; ++replica) {
104     for (int computation = 0; computation < computation_count; ++computation) {
105       TF_ASSIGN_OR_RETURN(
106           int device_id,
107           DeviceId(replica, computation, replica_count, computation_count));
108       assignment(replica, computation) = device_id;
109     }
110   }
111   return std::move(assignment);
112 }
113 
RegisterComputationPlacer(se::Platform::Id platform_id,ComputationPlacerCreationFunction creation_function)114 /* static */ void ComputationPlacer::RegisterComputationPlacer(
115     se::Platform::Id platform_id,
116     ComputationPlacerCreationFunction creation_function) {
117   tensorflow::mutex_lock lock(
118       ComputationPlacer::platform_computation_placer_mutex_);
119   auto* computation_placers = GetPlatformComputationPlacers();
120   CHECK(computation_placers->find(platform_id) == computation_placers->end());
121   (*computation_placers)[platform_id].creation_function = creation_function;
122 }
123 
GetForPlatform(const se::Platform * platform)124 /* static */ StatusOr<ComputationPlacer*> ComputationPlacer::GetForPlatform(
125     const se::Platform* platform) {
126   tensorflow::mutex_lock lock(
127       ComputationPlacer::platform_computation_placer_mutex_);
128   auto* computation_placers = GetPlatformComputationPlacers();
129 
130   auto it = computation_placers->find(platform->id());
131   if (it == computation_placers->end()) {
132     return NotFound(
133         "could not find registered computation placer for platform %s -- check "
134         "target linkage",
135         platform->Name());
136   }
137 
138   if (it->second.placer == nullptr) {
139     // Lazily create the computation placer the first time it is needed.
140     it->second.placer = (*it->second.creation_function)();
141   }
142 
143   return it->second.placer.get();
144 }
145 
146 /* static */ tensorflow::mutex
147     ComputationPlacer::platform_computation_placer_mutex_(
148         tensorflow::LINKER_INITIALIZED);
149 
150 /* static */ std::map<se::Platform::Id, ComputationPlacer::State>*
GetPlatformComputationPlacers()151 ComputationPlacer::GetPlatformComputationPlacers() {
152   static auto* r = new std::map<se::Platform::Id, ComputationPlacer::State>;
153   return r;
154 }
155 
156 }  // namespace xla
157 
CreateComputationPlacer()158 static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() {
159   return absl::make_unique<xla::ComputationPlacer>();
160 }
161 
InitModule()162 static bool InitModule() {
163   xla::ComputationPlacer::RegisterComputationPlacer(
164       stream_executor::host::kHostPlatformId, &CreateComputationPlacer);
165   xla::ComputationPlacer::RegisterComputationPlacer(
166       stream_executor::cuda::kCudaPlatformId, &CreateComputationPlacer);
167   return true;
168 }
169 static bool module_initialized = InitModule();
170