1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/pjrt/tpu_client.h"

#include <array>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/tpu_computation_placer.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/tpu/tpu_executable_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_stream.h"
43
44 namespace tf_tpu = tensorflow::tpu;
45
46 namespace xla {
47 namespace {
48
49 class TpuDeviceState : public LocalDeviceState {
50 public:
51 TpuDeviceState(se::StreamExecutor* executor, LocalClient* client,
52 bool asynchronous);
53
54 Status ThenMemcpyDeviceToDevice(se::Stream* transfer_stream,
55 se::Stream* dst_stream,
56 se::DeviceMemoryBase src_buffer,
57 se::DeviceMemoryBase dst_buffer) override;
58 };
59
TpuDeviceState(se::StreamExecutor * executor,LocalClient * client,bool asynchronous)60 TpuDeviceState::TpuDeviceState(se::StreamExecutor* executor,
61 LocalClient* client, bool asynchronous)
62 : LocalDeviceState(executor, client, LocalDeviceState::kAsynchronous,
63 asynchronous,
64 /*allow_event_reuse=*/false) {}
65
ThenMemcpyDeviceToDevice(se::Stream * transfer_stream,se::Stream * dst_stream,se::DeviceMemoryBase src_buffer,se::DeviceMemoryBase dst_buffer)66 Status TpuDeviceState::ThenMemcpyDeviceToDevice(
67 se::Stream* transfer_stream, se::Stream* dst_stream,
68 se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) {
69 auto* transfer_tpu_stream = tensorflow::down_cast<tf_tpu::TpuStream*>(
70 transfer_stream->implementation());
71 TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueOnTpuDeviceSendRecvLocal(
72 src_buffer, dst_buffer));
73 return Status::OK();
74 }
75
76 class PjRtTpuClient : public PjRtStreamExecutorClient {
77 public:
78 PjRtTpuClient(LocalClient* client,
79 std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
80 int task_id);
81
82 StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
83 int num_replicas, int num_partitions) const override;
84
EnqueueD2DTransfersOnSrcStream() const85 bool EnqueueD2DTransfersOnSrcStream() const override { return false; }
86
87 StatusOr<absl::optional<std::string>> ExecutableFingerprint(
88 const PjRtExecutable& executable) const override;
89 };
90
PjRtTpuClient(LocalClient * client,std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,int task_id)91 PjRtTpuClient::PjRtTpuClient(
92 LocalClient* client,
93 std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int task_id)
94 : PjRtStreamExecutorClient(kTpuName, client, std::move(devices), task_id,
95 /*allocator=*/nullptr,
96 /*host_memory_allocator=*/nullptr,
97 /*should_stage_host_to_device_transfers=*/false,
98 /*gpu_run_options=*/nullptr) {}
99
GetDefaultDeviceAssignment(int num_replicas,int num_partitions) const100 StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
101 int num_replicas, int num_partitions) const {
102 tf_tpu::TpuPlatformInterface* platform =
103 tf_tpu::TpuPlatformInterface::GetRegisteredPlatform();
104 tf_tpu::TpuHostLocationExternal host = platform->GetTpuHostLocation();
105 int num_local_devices = host.Cores(kTensorCore).size();
106 if (num_replicas * num_partitions <= num_local_devices) {
107 return tf_tpu::TpuComputationPlacer::AssignLocalDevices(host, num_replicas,
108 num_partitions);
109 }
110 // Fallback to default global device assignment if we can't run locally.
111 return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
112 num_partitions);
113 }
114
ExecutableFingerprint(const PjRtExecutable & executable) const115 StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
116 const PjRtExecutable& executable) const {
117 if (executable.client() != this) {
118 return InvalidArgument(
119 "Passed executable from different client (platform '%s') to "
120 "PjRtTpuClient::ExecutableFingerprint",
121 executable.client()->platform_name());
122 }
123 if (executable.num_partitions() > 1) {
124 LOG(INFO) << "ExecutableFingerprint not fully implemented for MPMD "
125 "executables, fingerprint may not be unique.";
126 }
127 xla::TpuExecutableInterface* tpu_executable =
128 tensorflow::down_cast<xla::TpuExecutableInterface*>(
129 tensorflow::down_cast<const PjRtStreamExecutorExecutable*>(
130 &executable)
131 ->executables()[0]
132 ->executable());
133 return absl::optional<std::string>(tpu_executable->fingerprint());
134 }
135
GetTpuDevices(LocalClient * client,std::vector<std::unique_ptr<LocalDeviceState>> local_device_states)136 StatusOr<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>> GetTpuDevices(
137 LocalClient* client,
138 std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
139 std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
140 tf_tpu::TpuTopologyExternal topology =
141 tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology();
142
143 std::map<int, int> core_id_to_device_ordinal;
144 for (int i = 0; i < client->device_count(); ++i) {
145 se::StreamExecutor* executor =
146 client->backend().stream_executor(i).ValueOrDie();
147 tf_tpu::TpuExecutorInterface* tpu_executor =
148 tensorflow::down_cast<tf_tpu::TpuExecutorInterface*>(
149 executor->implementation());
150 core_id_to_device_ordinal[tpu_executor->GetCoreLocationExternal().Id()] = i;
151 }
152
153 for (const tf_tpu::TpuCoreLocationExternal& core :
154 topology.cores(TpuCoreTypeEnum::kTensorCore)) {
155 auto it = core_id_to_device_ordinal.find(core.Id());
156 int device_ordinal =
157 (it != core_id_to_device_ordinal.end()) ? it->second : -1;
158 int task_id = topology.IdForHost(core.host_coordinates());
159 const tf_tpu::TpuDimensionsExternal coords = core.chip_coordinates();
160 std::array<int, 3> coords_array = {coords.x, coords.y, coords.z};
161 std::unique_ptr<LocalDeviceState> local_device_state;
162 if (device_ordinal >= 0) {
163 local_device_state = std::move(local_device_states[device_ordinal]);
164 }
165 auto device = absl::make_unique<PjRtTpuDevice>(
166 core, std::move(local_device_state), task_id, coords_array,
167 std::string(tf_tpu::TpuVersionEnumToString(topology.version())));
168 devices.push_back(std::move(device));
169 }
170 return devices;
171 }
172
173 } // namespace
174
GetTpuClient(bool asynchronous,absl::Duration init_retry_timeout)175 StatusOr<std::shared_ptr<PjRtClient>> GetTpuClient(
176 bool asynchronous, absl::Duration init_retry_timeout) {
177 tf_tpu::TpuPlatformInterface* platform =
178 tf_tpu::TpuPlatformInterface::GetRegisteredPlatform(
179 /*initialize_platform=*/true, /*num_tries=*/1);
180 if (platform == nullptr) {
181 return InvalidArgument("TpuPlatform is not available.");
182 }
183 // NOTE: We retry in a loop since some pod failures are transient (e.g. some
184 // RPCs may timeout waiting for other hosts to come up, but will succeed
185 // at a later point if retried).
186 auto start = absl::Now();
187 // TODO(b/165870356): TpuPlatform::Initialized() always returns true!
188 auto status = platform->Initialize({});
189 while (!platform->Initialized()) {
190 status = platform->Initialize({});
191 if (!status.ok()) {
192 LOG(ERROR) << "Platform initialization failed: " << status;
193 if ((absl::Now() - start) >= init_retry_timeout) {
194 return status;
195 }
196 }
197 }
198 if (platform->VisibleDeviceCount() <= 0) {
199 return InvalidArgument("No TPU devices found.");
200 }
201 LocalClientOptions options;
202 options.set_platform(platform);
203 TF_ASSIGN_OR_RETURN(LocalClient * client,
204 ClientLibrary::GetOrCreateLocalClient(options));
205
206 std::vector<std::unique_ptr<LocalDeviceState>> local_device_states;
207 local_device_states.reserve(client->device_count());
208 for (int i = 0; i < client->device_count(); ++i) {
209 se::StreamExecutor* executor =
210 client->backend().stream_executor(i).ValueOrDie();
211 local_device_states.push_back(
212 absl::make_unique<TpuDeviceState>(executor, client, asynchronous));
213 }
214
215 TF_ASSIGN_OR_RETURN(auto devices,
216 GetTpuDevices(client, std::move(local_device_states)));
217 int task_id = platform->GetTpuHostLocation().Id();
218
219 return std::shared_ptr<PjRtClient>(
220 absl::make_unique<PjRtTpuClient>(client, std::move(devices), task_id));
221 }
222
223 } // namespace xla
224