1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tpu/kernels/tpu_program_group.h"
16
17 #include "tensorflow/compiler/xla/service/hlo_module_group.h"
18 #include "tensorflow/compiler/xla/xla.pb.h"
19 #include "tensorflow/core/lib/gtl/cleanup.h"
20 #include "tensorflow/core/platform/casts.h"
21 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
22 #include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
23 #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
24 #include "tensorflow/core/tpu/tpu_api.h"
25 #include "tensorflow/core/tpu/tpu_ops_c_api.h"
26 #include "tensorflow/stream_executor/tpu/proto_helper.h"
27 #include "tensorflow/stream_executor/tpu/status_helper.h"
28
29 namespace tensorflow {
30 namespace tpu {
31 namespace {
32 namespace se_tpu = ::stream_executor::tpu;
33 using stream_executor::port::Status;
34 } // namespace
35
ConstructExecutableInfo(const XLA_TpuProgram * xla_tpu_program)36 TPUExecutableInfoProto TpuProgramGroup::ConstructExecutableInfo(
37 const XLA_TpuProgram* xla_tpu_program) {
38 VLOG(1) << "ConstructExecutableInfo";
39 TpuSerializedProto serialized_executable_info = {};
40 StatusHelper status;
41 OpsApiFn()->TpuProgram_GetExecutableInfoFn(
42 xla_tpu_program, &serialized_executable_info, status.c_status);
43 TPUExecutableInfoProto executable_info;
44 if (status.ok()) {
45 executable_info = se_tpu::DeserializeProto<TPUExecutableInfoProto>(
46 serialized_executable_info);
47 StreamExecutor_Tpu_FreeSerializedProto(&serialized_executable_info);
48 }
49 return executable_info;
50 }
51
ConstructHostTransferInfo(const XLA_TpuProgram * xla_tpu_program)52 TPUHostTransferInfoProto TpuProgramGroup::ConstructHostTransferInfo(
53 const XLA_TpuProgram* xla_tpu_program) {
54 VLOG(1) << "ConstructHostTransferInfo";
55 TpuSerializedProto serialized_host_transfer_info = {};
56 StatusHelper status;
57 OpsApiFn()->TpuProgram_GetHostTransferInfoFn(
58 xla_tpu_program, &serialized_host_transfer_info, status.c_status);
59 TPUHostTransferInfoProto host_transfer_info;
60 if (status.ok()) {
61 host_transfer_info = se_tpu::DeserializeProto<TPUHostTransferInfoProto>(
62 serialized_host_transfer_info);
63 StreamExecutor_Tpu_FreeSerializedProto(&serialized_host_transfer_info);
64 }
65 return host_transfer_info;
66 }
67
ConstructHloMetadata(const XLA_TpuProgram * xla_tpu_program)68 xla::HloProto TpuProgramGroup::ConstructHloMetadata(
69 const XLA_TpuProgram* xla_tpu_program) {
70 VLOG(1) << "ConstructHloMetadata";
71 TpuSerializedProto serialized_hlo_metadata = {};
72 StatusHelper status;
73 OpsApiFn()->TpuProgram_GetHloMetadataFn(
74 xla_tpu_program, &serialized_hlo_metadata, status.c_status);
75 xla::HloProto hlo_metadata;
76 if (status.ok()) {
77 hlo_metadata =
78 se_tpu::DeserializeProto<xla::HloProto>(serialized_hlo_metadata);
79 StreamExecutor_Tpu_FreeSerializedProto(&serialized_hlo_metadata);
80 }
81 return hlo_metadata;
82 }
83
Initialize(absl::Span<XLA_TpuProgram * const> xla_tpu_programs)84 void TpuProgramGroup::Initialize(
85 absl::Span<XLA_TpuProgram* const> xla_tpu_programs) {
86 CHECK_GT(xla_tpu_programs.size(), 0);
87 CHECK_EQ(program_count(), 0) << "Reinitialization of an existing "
88 "`TpuProgramGroup` instance is prohibited.";
89 set_tpu_programs(xla_tpu_programs);
90
91 CHECK_EQ(tpu_program_fingerprints_.size(), 0);
92 set_fingerprints();
93
94 std::vector<bool> may_modify_variables_array(tpu_programs_.size(), false);
95 std::vector<TPUExecutableInfoProto> executable_infos(tpu_programs_.size());
96 std::vector<TPUHostTransferInfoProto> host_transfer_infos(
97 tpu_programs_.size());
98 std::vector<xla::HloProto> hlo_metadatas(tpu_programs_.size());
99 for (size_t i = 0; i < tpu_programs_.size(); ++i) {
100 const XLA_TpuProgram* xla_tpu_program = tpu_programs_[i];
101 bool may_modify_variables;
102 OpsApiFn()->TpuProgram_GetMayModifyVariablesFn(xla_tpu_program,
103 &may_modify_variables);
104 may_modify_variables_array[i] = may_modify_variables;
105 executable_infos[i] = ConstructExecutableInfo(xla_tpu_program);
106 host_transfer_infos[i] = ConstructHostTransferInfo(xla_tpu_program);
107 hlo_metadatas[i] = ConstructHloMetadata(xla_tpu_program);
108 }
109
110 may_modify_variables_ = may_modify_variables_array;
111 executable_infos_ = executable_infos;
112 host_transfer_infos_ = host_transfer_infos;
113 hlo_metadatas_ = hlo_metadatas;
114 RefreshHloMetadatasPtrs();
115 }
116
has_sharding_program() const117 bool TpuProgramGroup::has_sharding_program() const {
118 for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
119 if (!OpsApiFn()->TpuProgram_HasShardingFn(tpu_program)) {
120 return false;
121 }
122 }
123 return true;
124 }
125
program_count() const126 size_t TpuProgramGroup::program_count() const { return tpu_programs_.size(); }
127
program_size() const128 int64_t TpuProgramGroup::program_size() const {
129 int64_t total_size = 0;
130 for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
131 total_size += OpsApiFn()->TpuProgram_GetProgramSizeFn(tpu_program);
132 }
133 return total_size;
134 }
135
LogProgramMemorySummary()136 bool TpuProgramGroup::LogProgramMemorySummary() {
137 bool success = true;
138 for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
139 success &= OpsApiFn()->TpuProgram_LogProgramMemorySummaryFn(tpu_program);
140 }
141 return success;
142 }
143
UnloadAndDestroyPrograms()144 void TpuProgramGroup::UnloadAndDestroyPrograms() {
145 for (XLA_TpuProgram* tpu_program : tpu_programs_) {
146 StatusHelper status;
147 OpsApiFn()->TpuProgram_UnloadAndDestroyFn(tpu_program, status.c_status);
148 auto s = status.status();
149 if (!s.ok()) {
150 LOG(ERROR) << "TpuProgramGroup::UnloadPrograms(): " << s.ToString();
151 }
152 }
153 tpu_programs_.clear();
154 }
155
TpuProgramGroup(TpuProgramGroup && other)156 TpuProgramGroup::TpuProgramGroup(TpuProgramGroup&& other)
157 : may_modify_variables_(std::move(other.may_modify_variables_)),
158 tpu_programs_(std::move(other.tpu_programs_)),
159 executable_infos_(std::move(other.executable_infos_)),
160 host_transfer_infos_(std::move(other.host_transfer_infos_)),
161 hlo_metadatas_(std::move(other.hlo_metadatas_)) {
162 RefreshHloMetadatasPtrs();
163 }
164
set_hlo_metadatas(absl::Span<const xla::HloProto> hlo_metadatas)165 void TpuProgramGroup::set_hlo_metadatas(
166 absl::Span<const xla::HloProto> hlo_metadatas) {
167 hlo_metadatas_.resize(hlo_metadatas.size());
168 for (size_t i = 0; i < hlo_metadatas.size(); ++i) {
169 hlo_metadatas_[i] = hlo_metadatas[i];
170 }
171 RefreshHloMetadatasPtrs();
172 }
173
// Returns a view over the cached pointers into `hlo_metadatas_`; valid until
// the group's metadata is next mutated.
absl::Span<const xla::HloProto* const> TpuProgramGroup::hlo_metadatas() const {
  return hlo_metadatas_ptrs_;
}
177
// Returns the HLO metadata for the program at `index`; CHECK-fails if
// `index` is out of range.
const xla::HloProto* TpuProgramGroup::hlo_metadata(int index) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, hlo_metadatas_ptrs_.size());
  return hlo_metadatas_ptrs_[index];
}
183
RefreshHloMetadatasPtrs()184 void TpuProgramGroup::RefreshHloMetadatasPtrs() {
185 hlo_metadatas_ptrs_.reserve(hlo_metadatas_.size());
186 for (const auto& hlo_metadata_internal_ : hlo_metadatas_) {
187 hlo_metadatas_ptrs_.push_back(&hlo_metadata_internal_);
188 }
189 }
190
// Records compilation statistics for `key` / `duration`. Currently a no-op
// that always returns OK.
Status TpuProgramGroup::LogCompilationStats(const TpuCompilationCacheKey& key,
                                            absl::Duration duration) {
  // A placeholder for tracking compilation statistics for future work. The
  // implementation can be pushing into some external storage for analytics.
  return Status::OK();
}
197
// Returns the cached per-program may-modify-variables flags.
const std::vector<bool>& TpuProgramGroup::may_modify_variables_list() const {
  return may_modify_variables_;
}
201
// Overwrites the cached per-program may-modify-variables flags.
void TpuProgramGroup::set_may_modify_variables(
    const std::vector<bool>& may_modify_variables) {
  may_modify_variables_ = may_modify_variables;
}
206
may_modify_variables(int index) const207 bool TpuProgramGroup::may_modify_variables(int index) const {
208 CHECK_GE(index, 0);
209 CHECK_LT(index, tpu_programs_.size());
210 bool may_modify_variables;
211 OpsApiFn()->TpuProgram_GetMayModifyVariablesFn(tpu_programs_[index],
212 &may_modify_variables);
213 return may_modify_variables;
214 }
215
// Returns all programs in the group (non-owning view of the member vector).
const std::vector<XLA_TpuProgram*>& TpuProgramGroup::tpu_programs() const {
  return tpu_programs_;
}
219
// Returns the per-program fingerprints collected by set_fingerprints().
const std::vector<std::string>& TpuProgramGroup::fingerprints() const {
  return tpu_program_fingerprints_;
}
223
set_fingerprints()224 void TpuProgramGroup::set_fingerprints() {
225 for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
226 TpuProgramFingerprint fingerprint =
227 OpsApiFn()->TpuProgram_GetFingerprintFn(tpu_program);
228 tpu_program_fingerprints_.emplace_back(
229 std::string(fingerprint.bytes, fingerprint.size));
230 OpsApiFn()->TpuProgram_DestroyFingerprintFn(fingerprint);
231 }
232 }
233
// Returns the fingerprint of the program at `index`; throws
// std::out_of_range (via vector::at) if `index` is invalid.
const std::string& TpuProgramGroup::fingerprint(int index) const {
  return fingerprints().at(index);
}
237
// Returns the program at `index`; CHECK-fails if `index` is out of range.
const XLA_TpuProgram* TpuProgramGroup::tpu_program(int index) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, tpu_programs_.size());
  return tpu_programs_[index];
}
243
set_tpu_programs(absl::Span<XLA_TpuProgram * const> tpu_programs)244 void TpuProgramGroup::set_tpu_programs(
245 absl::Span<XLA_TpuProgram* const> tpu_programs) {
246 tpu_programs_.resize(tpu_programs.size());
247 for (size_t i = 0; i < tpu_programs.size(); ++i) {
248 tpu_programs_[i] = tpu_programs[i];
249 }
250 }
251
// Returns the executable info for the program at `index`; CHECK-fails if
// `index` is out of range.
const TPUExecutableInfoProto& TpuProgramGroup::executable_info(
    int index) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, executable_infos_.size());
  return executable_infos_[index];
}
258
// Returns the host transfer info for the program at `index`; CHECK-fails if
// `index` is out of range.
const TPUHostTransferInfoProto& TpuProgramGroup::host_transfer_info(
    int index) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, host_transfer_infos_.size());
  return host_transfer_infos_[index];
}
265
266 /*static*/
CompileAndBuild(const TpuCompilationRequestProto & compilation_request,const XLA_TpuMeshState * mesh_state,TpuProgramGroupInterface * tpu_program_group_interface)267 Status TpuProgramGroup::CompileAndBuild(
268 const TpuCompilationRequestProto& compilation_request,
269 const XLA_TpuMeshState* mesh_state,
270 TpuProgramGroupInterface* tpu_program_group_interface) {
271 se_tpu::SerializedProto serialized_compilation_request =
272 se_tpu::SerializeProto(compilation_request);
273 auto cleanup = gtl::MakeCleanup([serialized_compilation_request] {
274 se_tpu::SerializedProto_Free(serialized_compilation_request);
275 });
276 size_t count = 0;
277 XLA_TpuProgram** xla_tpu_programs = nullptr;
278 StatusHelper status;
279 OpsApiFn()->TpuCompile_CompileAndBuildFn(serialized_compilation_request,
280 mesh_state, &xla_tpu_programs,
281 &count, status.c_status);
282 if (!status.ok()) {
283 VLOG(1) << "Run CompileAndBuild failed.";
284 return status.status();
285 }
286
287 // SPMD could return 1 result for all partitions.
288 TF_RET_CHECK(count == 1 ||
289 count == compilation_request.metadata().num_cores_per_replica());
290
291 VLOG(1) << "Initialize TpuProgramGroup.";
292 TpuProgramGroup* tpu_program_group =
293 tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
294 tpu_program_group->Initialize(
295 absl::MakeConstSpan(&xla_tpu_programs[0], count));
296 OpsApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs);
297 return status.status();
298 }
299
300 /*static*/
CompileAndBuild(const xrt::XLAComputation & xrt_computation_proto,const XLA_TpuMeshState * mesh_state,TpuProgramGroupInterface * tpu_program_group_interface)301 Status TpuProgramGroup::CompileAndBuild(
302 const xrt::XLAComputation& xrt_computation_proto,
303 const XLA_TpuMeshState* mesh_state,
304 TpuProgramGroupInterface* tpu_program_group_interface) {
305 se_tpu::SerializedProto serialized_compilation_request =
306 se_tpu::SerializeProto(xrt_computation_proto);
307 auto cleanup = gtl::MakeCleanup([serialized_compilation_request] {
308 se_tpu::SerializedProto_Free(serialized_compilation_request);
309 });
310 size_t count = 0;
311 XLA_TpuProgram** xla_tpu_programs = nullptr;
312 StatusHelper status;
313 OpsApiFn()->TpuCompile_XrtCompileAndBuildFn(serialized_compilation_request,
314 mesh_state, &xla_tpu_programs,
315 &count, status.c_status);
316 if (!status.ok()) {
317 VLOG(1) << "Run CompileAndBuild failed.";
318 return status.status();
319 }
320
321 // SPMD could return 1 result for all partitions.
322 int num_cores_per_replica =
323 xrt_computation_proto.config().num_cores_per_replica()
324 ? xrt_computation_proto.config().num_cores_per_replica()
325 : 1;
326 TF_RET_CHECK(count == 1 || count == num_cores_per_replica);
327 VLOG(1) << "Initialize TpuProgramGroup.";
328 TpuProgramGroup* tpu_program_group =
329 tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
330 tpu_program_group->Initialize(
331 absl::MakeConstSpan(&xla_tpu_programs[0], count));
332 OpsApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs);
333 return status.status();
334 }
335
tpu_programs(TpuProgramShardingType sharding_type) const336 std::vector<XLA_TpuProgram*> TpuProgramGroup::tpu_programs(
337 TpuProgramShardingType sharding_type) const {
338 std::vector<XLA_TpuProgram*> tpu_programs;
339 tpu_programs.reserve(tpu_programs_.size());
340 for (size_t i = 0; i < tpu_programs_.size(); ++i) {
341 if (OpsApiFn()->TpuProgram_HasShardingFn(tpu_programs_[i])) {
342 tpu_programs.push_back(OpsApiFn()->TpuProgram_GetTpuProgramFn(
343 tpu_programs_[i], sharding_type));
344 CHECK_NE(tpu_programs[i], nullptr);
345 }
346 }
347 return tpu_programs;
348 }
349
DeserializeFromRpcResponseProtos(const std::vector<TpuSerializedProto> & rpc_response_protos)350 Status TpuProgramGroup::DeserializeFromRpcResponseProtos(
351 const std::vector<TpuSerializedProto>& rpc_response_protos) {
352 std::vector<XLA_TpuProgram*> tpu_programs;
353 tpu_programs.resize(rpc_response_protos.size());
354
355 for (size_t i = 0; i < rpc_response_protos.size(); ++i) {
356 StatusHelper status;
357 auto* xla_tpu_program = OpsApiFn()->TpuProgram_NewFn();
358 OpsApiFn()->TpuProgram_DeserializeFromGetTpuProgramResponseProtoFn(
359 rpc_response_protos[i], xla_tpu_program, status.c_status);
360 if (!status.status().ok()) {
361 OpsApiFn()->TpuProgram_FreeFn(xla_tpu_program);
362 return status.status();
363 }
364 tpu_programs[i] = xla_tpu_program;
365 }
366
367 Initialize(tpu_programs);
368 return Status::OK();
369 }
370
// Serializes the executable of the program at `index` into `executable` via
// the TPU ops C API; CHECK-fails if `index` is out of range.
Status TpuProgramGroup::SerializeExecutable(
    int index, TpuExecutableSerializedProto* executable) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, tpu_programs_.size());
  StatusHelper status;
  OpsApiFn()->TpuProgram_SerializeTpuExecutableFn(tpu_programs_[index],
                                                  executable, status.c_status);
  return status.status();
}
380
// Serializes the compiler metadata of the program at `index` into
// `compiler_metadata` via the TPU ops C API; CHECK-fails if `index` is out
// of range.
Status TpuProgramGroup::SerializeCompilerMetadata(
    int index, CompilerMetadataSerializedProto* compiler_metadata) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, tpu_programs_.size());
  StatusHelper status;
  OpsApiFn()->TpuProgram_SerializeCompilerMetadataFn(
      tpu_programs_[index], compiler_metadata, status.c_status);
  return status.status();
}
390 } // namespace tpu
391 } // namespace tensorflow
392