1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tpu/kernels/tpu_program_group.h"
16
17 #include "tensorflow/compiler/xla/service/hlo_module_group.h"
18 #include "tensorflow/compiler/xla/xla.pb.h"
19 #include "tensorflow/core/lib/gtl/cleanup.h"
20 #include "tensorflow/core/platform/casts.h"
21 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
22 #include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
23 #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
24 #include "tensorflow/core/tpu/tpu_api.h"
25 #include "tensorflow/core/tpu/tpu_ops_c_api.h"
26 #include "tensorflow/stream_executor/tpu/proto_helper.h"
27 #include "tensorflow/stream_executor/tpu/status_helper.h"
28
29 namespace tensorflow {
30 namespace tpu {
31 namespace {
32 namespace se_tpu = ::stream_executor::tpu;
33 using stream_executor::port::Status;
34 } // namespace
35
ConstructExecutableInfo(const XLA_TpuProgram * xla_tpu_program)36 TPUExecutableInfoProto TpuProgramGroup::ConstructExecutableInfo(
37 const XLA_TpuProgram* xla_tpu_program) {
38 VLOG(1) << "ConstructExecutableInfo";
39 TpuSerializedProto serialized_executable_info = {};
40 StatusHelper status;
41 OpsApiFn()->TpuProgram_GetExecutableInfoFn(
42 xla_tpu_program, &serialized_executable_info, status.c_status);
43 TPUExecutableInfoProto executable_info;
44 if (status.ok()) {
45 executable_info = se_tpu::DeserializeProto<TPUExecutableInfoProto>(
46 serialized_executable_info);
47 StreamExecutor_Tpu_FreeSerializedProto(&serialized_executable_info);
48 }
49 return executable_info;
50 }
51
ConstructHostTransferInfo(const XLA_TpuProgram * xla_tpu_program)52 TPUHostTransferInfoProto TpuProgramGroup::ConstructHostTransferInfo(
53 const XLA_TpuProgram* xla_tpu_program) {
54 VLOG(1) << "ConstructHostTransferInfo";
55 TpuSerializedProto serialized_host_transfer_info = {};
56 StatusHelper status;
57 OpsApiFn()->TpuProgram_GetHostTransferInfoFn(
58 xla_tpu_program, &serialized_host_transfer_info, status.c_status);
59 TPUHostTransferInfoProto host_transfer_info;
60 if (status.ok()) {
61 host_transfer_info = se_tpu::DeserializeProto<TPUHostTransferInfoProto>(
62 serialized_host_transfer_info);
63 StreamExecutor_Tpu_FreeSerializedProto(&serialized_host_transfer_info);
64 }
65 return host_transfer_info;
66 }
67
ConstructHloMetadata(const XLA_TpuProgram * xla_tpu_program)68 xla::HloProto TpuProgramGroup::ConstructHloMetadata(
69 const XLA_TpuProgram* xla_tpu_program) {
70 VLOG(1) << "ConstructHloMetadata";
71 TpuSerializedProto serialized_hlo_metadata = {};
72 StatusHelper status;
73 OpsApiFn()->TpuProgram_GetHloMetadataFn(
74 xla_tpu_program, &serialized_hlo_metadata, status.c_status);
75 xla::HloProto hlo_metadata;
76 if (status.ok()) {
77 hlo_metadata =
78 se_tpu::DeserializeProto<xla::HloProto>(serialized_hlo_metadata);
79 StreamExecutor_Tpu_FreeSerializedProto(&serialized_hlo_metadata);
80 }
81 return hlo_metadata;
82 }
83
Initialize(absl::Span<XLA_TpuProgram * const> xla_tpu_programs)84 void TpuProgramGroup::Initialize(
85 absl::Span<XLA_TpuProgram* const> xla_tpu_programs) {
86 CHECK_GT(xla_tpu_programs.size(), 0);
87 CHECK_EQ(program_count(), 0) << "Reinitialization of an existing "
88 "`TpuProgramGroup` instance is prohibited.";
89 set_tpu_programs(xla_tpu_programs);
90
91 std::vector<bool> may_modify_variables_array(tpu_programs_.size(), false);
92 std::vector<TPUExecutableInfoProto> executable_infos(tpu_programs_.size());
93 std::vector<TPUHostTransferInfoProto> host_transfer_infos(
94 tpu_programs_.size());
95 std::vector<xla::HloProto> hlo_metadatas(tpu_programs_.size());
96 for (size_t i = 0; i < tpu_programs_.size(); ++i) {
97 const XLA_TpuProgram* xla_tpu_program = tpu_programs_[i];
98 bool may_modify_variables;
99 OpsApiFn()->TpuProgram_GetMayModifyVariablesFn(xla_tpu_program,
100 &may_modify_variables);
101 may_modify_variables_array[i] = may_modify_variables;
102 executable_infos[i] = ConstructExecutableInfo(xla_tpu_program);
103 host_transfer_infos[i] = ConstructHostTransferInfo(xla_tpu_program);
104 hlo_metadatas[i] = ConstructHloMetadata(xla_tpu_program);
105 }
106
107 may_modify_variables_ = may_modify_variables_array;
108 executable_infos_ = executable_infos;
109 host_transfer_infos_ = host_transfer_infos;
110 hlo_metadatas_ = hlo_metadatas;
111 RefreshHloMetadatasPtrs();
112 }
113
has_sharding_program() const114 bool TpuProgramGroup::has_sharding_program() const {
115 for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
116 if (!OpsApiFn()->TpuProgram_HasShardingFn(tpu_program)) {
117 return false;
118 }
119 }
120 return true;
121 }
122
program_count() const123 size_t TpuProgramGroup::program_count() const { return tpu_programs_.size(); }
124
program_size() const125 int64_t TpuProgramGroup::program_size() const {
126 int64_t total_size = 0;
127 for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
128 total_size += OpsApiFn()->TpuProgram_GetProgramSizeFn(tpu_program);
129 }
130 return total_size;
131 }
132
LogProgramMemorySummary()133 bool TpuProgramGroup::LogProgramMemorySummary() {
134 bool success = true;
135 for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
136 success &= OpsApiFn()->TpuProgram_LogProgramMemorySummaryFn(tpu_program);
137 }
138 return success;
139 }
140
UnloadAndDestroyPrograms()141 void TpuProgramGroup::UnloadAndDestroyPrograms() {
142 for (XLA_TpuProgram* tpu_program : tpu_programs_) {
143 StatusHelper status;
144 OpsApiFn()->TpuProgram_UnloadAndDestroyFn(tpu_program, status.c_status);
145 auto s = status.status();
146 if (!s.ok()) {
147 LOG(ERROR) << "TpuProgramGroup::UnloadPrograms(): " << s.ToString();
148 }
149 }
150 tpu_programs_.clear();
151 }
152
TpuProgramGroup(TpuProgramGroup && other)153 TpuProgramGroup::TpuProgramGroup(TpuProgramGroup&& other)
154 : may_modify_variables_(std::move(other.may_modify_variables_)),
155 tpu_programs_(std::move(other.tpu_programs_)),
156 executable_infos_(std::move(other.executable_infos_)),
157 host_transfer_infos_(std::move(other.host_transfer_infos_)),
158 hlo_metadatas_(std::move(other.hlo_metadatas_)) {
159 RefreshHloMetadatasPtrs();
160 }
161
set_hlo_metadatas(absl::Span<const xla::HloProto> hlo_metadatas)162 void TpuProgramGroup::set_hlo_metadatas(
163 absl::Span<const xla::HloProto> hlo_metadatas) {
164 hlo_metadatas_.resize(hlo_metadatas.size());
165 for (size_t i = 0; i < hlo_metadatas.size(); ++i) {
166 hlo_metadatas_[i] = hlo_metadatas[i];
167 }
168 RefreshHloMetadatasPtrs();
169 }
170
hlo_metadatas() const171 absl::Span<const xla::HloProto* const> TpuProgramGroup::hlo_metadatas() const {
172 return hlo_metadatas_ptrs_;
173 }
174
hlo_metadata(int index) const175 const xla::HloProto* TpuProgramGroup::hlo_metadata(int index) const {
176 CHECK_GE(index, 0);
177 CHECK_LT(index, hlo_metadatas_ptrs_.size());
178 return hlo_metadatas_ptrs_[index];
179 }
180
RefreshHloMetadatasPtrs()181 void TpuProgramGroup::RefreshHloMetadatasPtrs() {
182 hlo_metadatas_ptrs_.reserve(hlo_metadatas_.size());
183 for (const auto& hlo_metadata_internal_ : hlo_metadatas_) {
184 hlo_metadatas_ptrs_.push_back(&hlo_metadata_internal_);
185 }
186 }
187
LogCompilationStats(const TpuCompilationCacheKey & key,absl::Duration duration)188 Status TpuProgramGroup::LogCompilationStats(const TpuCompilationCacheKey& key,
189 absl::Duration duration) {
190 // A placeholder for tracking compilation statistics for future work. The
191 // implementation can be pushing into some external storage for analytics.
192 return Status::OK();
193 }
194
may_modify_variables_list() const195 const std::vector<bool>& TpuProgramGroup::may_modify_variables_list() const {
196 return may_modify_variables_;
197 }
198
set_may_modify_variables(const std::vector<bool> & may_modify_variables)199 void TpuProgramGroup::set_may_modify_variables(
200 const std::vector<bool>& may_modify_variables) {
201 may_modify_variables_ = may_modify_variables;
202 }
203
may_modify_variables(int index) const204 bool TpuProgramGroup::may_modify_variables(int index) const {
205 CHECK_GE(index, 0);
206 CHECK_LT(index, tpu_programs_.size());
207 bool may_modify_variables;
208 OpsApiFn()->TpuProgram_GetMayModifyVariablesFn(tpu_programs_[index],
209 &may_modify_variables);
210 return may_modify_variables;
211 }
212
tpu_programs() const213 const std::vector<XLA_TpuProgram*>& TpuProgramGroup::tpu_programs() const {
214 return tpu_programs_;
215 }
216
tpu_program(int index) const217 const XLA_TpuProgram* TpuProgramGroup::tpu_program(int index) const {
218 CHECK_GE(index, 0);
219 CHECK_LT(index, tpu_programs_.size());
220 return tpu_programs_[index];
221 }
222
set_tpu_programs(absl::Span<XLA_TpuProgram * const> tpu_programs)223 void TpuProgramGroup::set_tpu_programs(
224 absl::Span<XLA_TpuProgram* const> tpu_programs) {
225 tpu_programs_.resize(tpu_programs.size());
226 for (size_t i = 0; i < tpu_programs.size(); ++i) {
227 tpu_programs_[i] = tpu_programs[i];
228 }
229 }
230
executable_info(int index) const231 const TPUExecutableInfoProto& TpuProgramGroup::executable_info(
232 int index) const {
233 CHECK_GE(index, 0);
234 CHECK_LT(index, executable_infos_.size());
235 return executable_infos_[index];
236 }
237
host_transfer_info(int index) const238 const TPUHostTransferInfoProto& TpuProgramGroup::host_transfer_info(
239 int index) const {
240 CHECK_GE(index, 0);
241 CHECK_LT(index, host_transfer_infos_.size());
242 return host_transfer_infos_[index];
243 }
244
245 /*static*/
CompileAndBuild(const TpuCompilationRequestProto & compilation_request,const XLA_TpuMeshState * mesh_state,TpuProgramGroupInterface * tpu_program_group_interface)246 Status TpuProgramGroup::CompileAndBuild(
247 const TpuCompilationRequestProto& compilation_request,
248 const XLA_TpuMeshState* mesh_state,
249 TpuProgramGroupInterface* tpu_program_group_interface) {
250 se_tpu::SerializedProto serialized_compilation_request =
251 se_tpu::SerializeProto(compilation_request);
252 auto cleanup = gtl::MakeCleanup([serialized_compilation_request] {
253 se_tpu::SerializedProto_Free(serialized_compilation_request);
254 });
255 size_t count = 0;
256 XLA_TpuProgram** xla_tpu_programs = nullptr;
257 StatusHelper status;
258 OpsApiFn()->TpuCompile_CompileAndBuildFn(serialized_compilation_request,
259 mesh_state, &xla_tpu_programs,
260 &count, status.c_status);
261 if (!status.ok()) {
262 VLOG(1) << "Run CompileAndBuild failed.";
263 return status.status();
264 }
265
266 // SPMD could return 1 result for all partitions.
267 TF_RET_CHECK(count == 1 ||
268 count == compilation_request.metadata().num_cores_per_replica());
269
270 VLOG(1) << "Initialize TpuProgramGroup.";
271 TpuProgramGroup* tpu_program_group =
272 tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
273 tpu_program_group->Initialize(
274 absl::MakeConstSpan(&xla_tpu_programs[0], count));
275 OpsApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs);
276 return status.status();
277 }
278
279 /*static*/
CompileAndBuild(const xrt::XLAComputation & xrt_computation_proto,const XLA_TpuMeshState * mesh_state,TpuProgramGroupInterface * tpu_program_group_interface)280 Status TpuProgramGroup::CompileAndBuild(
281 const xrt::XLAComputation& xrt_computation_proto,
282 const XLA_TpuMeshState* mesh_state,
283 TpuProgramGroupInterface* tpu_program_group_interface) {
284 se_tpu::SerializedProto serialized_compilation_request =
285 se_tpu::SerializeProto(xrt_computation_proto);
286 auto cleanup = gtl::MakeCleanup([serialized_compilation_request] {
287 se_tpu::SerializedProto_Free(serialized_compilation_request);
288 });
289 size_t count = 0;
290 XLA_TpuProgram** xla_tpu_programs = nullptr;
291 StatusHelper status;
292 OpsApiFn()->TpuCompile_XrtCompileAndBuildFn(serialized_compilation_request,
293 mesh_state, &xla_tpu_programs,
294 &count, status.c_status);
295 if (!status.ok()) {
296 VLOG(1) << "Run CompileAndBuild failed.";
297 return status.status();
298 }
299
300 // SPMD could return 1 result for all partitions.
301 int num_cores_per_replica =
302 xrt_computation_proto.config().num_cores_per_replica()
303 ? xrt_computation_proto.config().num_cores_per_replica()
304 : 1;
305 TF_RET_CHECK(count == 1 || count == num_cores_per_replica);
306 VLOG(1) << "Initialize TpuProgramGroup.";
307 TpuProgramGroup* tpu_program_group =
308 tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
309 tpu_program_group->Initialize(
310 absl::MakeConstSpan(&xla_tpu_programs[0], count));
311 OpsApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs);
312 return status.status();
313 }
314
tpu_programs(TpuProgramShardingType sharding_type) const315 std::vector<XLA_TpuProgram*> TpuProgramGroup::tpu_programs(
316 TpuProgramShardingType sharding_type) const {
317 std::vector<XLA_TpuProgram*> tpu_programs;
318 tpu_programs.reserve(tpu_programs_.size());
319 for (size_t i = 0; i < tpu_programs_.size(); ++i) {
320 if (OpsApiFn()->TpuProgram_HasShardingFn(tpu_programs_[i])) {
321 tpu_programs.push_back(OpsApiFn()->TpuProgram_GetTpuProgramFn(
322 tpu_programs_[i], sharding_type));
323 CHECK_NE(tpu_programs[i], nullptr);
324 }
325 }
326 return tpu_programs;
327 }
328
DeserializeFromRpcResponseProtos(const std::vector<TpuSerializedProto> & rpc_response_protos)329 Status TpuProgramGroup::DeserializeFromRpcResponseProtos(
330 const std::vector<TpuSerializedProto>& rpc_response_protos) {
331 std::vector<XLA_TpuProgram*> tpu_programs;
332 tpu_programs.resize(rpc_response_protos.size());
333
334 for (size_t i = 0; i < rpc_response_protos.size(); ++i) {
335 StatusHelper status;
336 auto* xla_tpu_program = OpsApiFn()->TpuProgram_NewFn();
337 OpsApiFn()->TpuProgram_DeserializeFromGetTpuProgramResponseProtoFn(
338 rpc_response_protos[i], xla_tpu_program, status.c_status);
339 if (!status.status().ok()) {
340 OpsApiFn()->TpuProgram_FreeFn(xla_tpu_program);
341 return status.status();
342 }
343 tpu_programs[i] = xla_tpu_program;
344 }
345
346 Initialize(tpu_programs);
347 return Status::OK();
348 }
349
SerializeExecutable(int index,TpuExecutableSerializedProto * executable) const350 Status TpuProgramGroup::SerializeExecutable(
351 int index, TpuExecutableSerializedProto* executable) const {
352 CHECK_GE(index, 0);
353 CHECK_LT(index, tpu_programs_.size());
354 StatusHelper status;
355 OpsApiFn()->TpuProgram_SerializeTpuExecutableFn(tpu_programs_[index],
356 executable, status.c_status);
357 return status.status();
358 }
359
SerializeCompilerMetadata(int index,CompilerMetadataSerializedProto * compiler_metadata) const360 Status TpuProgramGroup::SerializeCompilerMetadata(
361 int index, CompilerMetadataSerializedProto* compiler_metadata) const {
362 CHECK_GE(index, 0);
363 CHECK_LT(index, tpu_programs_.size());
364 StatusHelper status;
365 OpsApiFn()->TpuProgram_SerializeCompilerMetadataFn(
366 tpu_programs_[index], compiler_metadata, status.c_status);
367 return status.status();
368 }
369 } // namespace tpu
370 } // namespace tensorflow
371