• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tpu/kernels/tpu_program_group.h"
16 
17 #include "tensorflow/compiler/xla/service/hlo_module_group.h"
18 #include "tensorflow/compiler/xla/xla.pb.h"
19 #include "tensorflow/core/lib/gtl/cleanup.h"
20 #include "tensorflow/core/platform/casts.h"
21 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
22 #include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
23 #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
24 #include "tensorflow/core/tpu/tpu_api.h"
25 #include "tensorflow/core/tpu/tpu_ops_c_api.h"
26 #include "tensorflow/stream_executor/tpu/proto_helper.h"
27 #include "tensorflow/stream_executor/tpu/status_helper.h"
28 
namespace tensorflow {
namespace tpu {
namespace {
// File-local aliases: `se_tpu` shortens the stream-executor TPU proto
// helpers, and `Status` refers to the stream-executor port status type.
namespace se_tpu = ::stream_executor::tpu;
using stream_executor::port::Status;
}  // namespace
ConstructExecutableInfo(const XLA_TpuProgram * xla_tpu_program)36 TPUExecutableInfoProto TpuProgramGroup::ConstructExecutableInfo(
37     const XLA_TpuProgram* xla_tpu_program) {
38   VLOG(1) << "ConstructExecutableInfo";
39   TpuSerializedProto serialized_executable_info = {};
40   StatusHelper status;
41   OpsApiFn()->TpuProgram_GetExecutableInfoFn(
42       xla_tpu_program, &serialized_executable_info, status.c_status);
43   TPUExecutableInfoProto executable_info;
44   if (status.ok()) {
45     executable_info = se_tpu::DeserializeProto<TPUExecutableInfoProto>(
46         serialized_executable_info);
47     StreamExecutor_Tpu_FreeSerializedProto(&serialized_executable_info);
48   }
49   return executable_info;
50 }
51 
ConstructHostTransferInfo(const XLA_TpuProgram * xla_tpu_program)52 TPUHostTransferInfoProto TpuProgramGroup::ConstructHostTransferInfo(
53     const XLA_TpuProgram* xla_tpu_program) {
54   VLOG(1) << "ConstructHostTransferInfo";
55   TpuSerializedProto serialized_host_transfer_info = {};
56   StatusHelper status;
57   OpsApiFn()->TpuProgram_GetHostTransferInfoFn(
58       xla_tpu_program, &serialized_host_transfer_info, status.c_status);
59   TPUHostTransferInfoProto host_transfer_info;
60   if (status.ok()) {
61     host_transfer_info = se_tpu::DeserializeProto<TPUHostTransferInfoProto>(
62         serialized_host_transfer_info);
63     StreamExecutor_Tpu_FreeSerializedProto(&serialized_host_transfer_info);
64   }
65   return host_transfer_info;
66 }
67 
ConstructHloMetadata(const XLA_TpuProgram * xla_tpu_program)68 xla::HloProto TpuProgramGroup::ConstructHloMetadata(
69     const XLA_TpuProgram* xla_tpu_program) {
70   VLOG(1) << "ConstructHloMetadata";
71   TpuSerializedProto serialized_hlo_metadata = {};
72   StatusHelper status;
73   OpsApiFn()->TpuProgram_GetHloMetadataFn(
74       xla_tpu_program, &serialized_hlo_metadata, status.c_status);
75   xla::HloProto hlo_metadata;
76   if (status.ok()) {
77     hlo_metadata =
78         se_tpu::DeserializeProto<xla::HloProto>(serialized_hlo_metadata);
79     StreamExecutor_Tpu_FreeSerializedProto(&serialized_hlo_metadata);
80   }
81   return hlo_metadata;
82 }
83 
Initialize(absl::Span<XLA_TpuProgram * const> xla_tpu_programs)84 void TpuProgramGroup::Initialize(
85     absl::Span<XLA_TpuProgram* const> xla_tpu_programs) {
86   CHECK_GT(xla_tpu_programs.size(), 0);
87   CHECK_EQ(program_count(), 0) << "Reinitialization of an existing "
88                                   "`TpuProgramGroup` instance is prohibited.";
89   set_tpu_programs(xla_tpu_programs);
90 
91   CHECK_EQ(tpu_program_fingerprints_.size(), 0);
92   set_fingerprints();
93 
94   std::vector<bool> may_modify_variables_array(tpu_programs_.size(), false);
95   std::vector<TPUExecutableInfoProto> executable_infos(tpu_programs_.size());
96   std::vector<TPUHostTransferInfoProto> host_transfer_infos(
97       tpu_programs_.size());
98   std::vector<xla::HloProto> hlo_metadatas(tpu_programs_.size());
99   for (size_t i = 0; i < tpu_programs_.size(); ++i) {
100     const XLA_TpuProgram* xla_tpu_program = tpu_programs_[i];
101     bool may_modify_variables;
102     OpsApiFn()->TpuProgram_GetMayModifyVariablesFn(xla_tpu_program,
103                                                    &may_modify_variables);
104     may_modify_variables_array[i] = may_modify_variables;
105     executable_infos[i] = ConstructExecutableInfo(xla_tpu_program);
106     host_transfer_infos[i] = ConstructHostTransferInfo(xla_tpu_program);
107     hlo_metadatas[i] = ConstructHloMetadata(xla_tpu_program);
108   }
109 
110   may_modify_variables_ = may_modify_variables_array;
111   executable_infos_ = executable_infos;
112   host_transfer_infos_ = host_transfer_infos;
113   hlo_metadatas_ = hlo_metadatas;
114   RefreshHloMetadatasPtrs();
115 }
116 
has_sharding_program() const117 bool TpuProgramGroup::has_sharding_program() const {
118   for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
119     if (!OpsApiFn()->TpuProgram_HasShardingFn(tpu_program)) {
120       return false;
121     }
122   }
123   return true;
124 }
125 
program_count() const126 size_t TpuProgramGroup::program_count() const { return tpu_programs_.size(); }
127 
program_size() const128 int64_t TpuProgramGroup::program_size() const {
129   int64_t total_size = 0;
130   for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
131     total_size += OpsApiFn()->TpuProgram_GetProgramSizeFn(tpu_program);
132   }
133   return total_size;
134 }
135 
LogProgramMemorySummary()136 bool TpuProgramGroup::LogProgramMemorySummary() {
137   bool success = true;
138   for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
139     success &= OpsApiFn()->TpuProgram_LogProgramMemorySummaryFn(tpu_program);
140   }
141   return success;
142 }
143 
UnloadAndDestroyPrograms()144 void TpuProgramGroup::UnloadAndDestroyPrograms() {
145   for (XLA_TpuProgram* tpu_program : tpu_programs_) {
146     StatusHelper status;
147     OpsApiFn()->TpuProgram_UnloadAndDestroyFn(tpu_program, status.c_status);
148     auto s = status.status();
149     if (!s.ok()) {
150       LOG(ERROR) << "TpuProgramGroup::UnloadPrograms(): " << s.ToString();
151     }
152   }
153   tpu_programs_.clear();
154 }
155 
TpuProgramGroup(TpuProgramGroup && other)156 TpuProgramGroup::TpuProgramGroup(TpuProgramGroup&& other)
157     : may_modify_variables_(std::move(other.may_modify_variables_)),
158       tpu_programs_(std::move(other.tpu_programs_)),
159       executable_infos_(std::move(other.executable_infos_)),
160       host_transfer_infos_(std::move(other.host_transfer_infos_)),
161       hlo_metadatas_(std::move(other.hlo_metadatas_)) {
162   RefreshHloMetadatasPtrs();
163 }
164 
set_hlo_metadatas(absl::Span<const xla::HloProto> hlo_metadatas)165 void TpuProgramGroup::set_hlo_metadatas(
166     absl::Span<const xla::HloProto> hlo_metadatas) {
167   hlo_metadatas_.resize(hlo_metadatas.size());
168   for (size_t i = 0; i < hlo_metadatas.size(); ++i) {
169     hlo_metadatas_[i] = hlo_metadatas[i];
170   }
171   RefreshHloMetadatasPtrs();
172 }
173 
hlo_metadatas() const174 absl::Span<const xla::HloProto* const> TpuProgramGroup::hlo_metadatas() const {
175   return hlo_metadatas_ptrs_;
176 }
177 
hlo_metadata(int index) const178 const xla::HloProto* TpuProgramGroup::hlo_metadata(int index) const {
179   CHECK_GE(index, 0);
180   CHECK_LT(index, hlo_metadatas_ptrs_.size());
181   return hlo_metadatas_ptrs_[index];
182 }
183 
RefreshHloMetadatasPtrs()184 void TpuProgramGroup::RefreshHloMetadatasPtrs() {
185   hlo_metadatas_ptrs_.reserve(hlo_metadatas_.size());
186   for (const auto& hlo_metadata_internal_ : hlo_metadatas_) {
187     hlo_metadatas_ptrs_.push_back(&hlo_metadata_internal_);
188   }
189 }
190 
// Placeholder hook for recording per-compilation statistics.
// `key` identifies the compilation-cache entry and `duration` is the
// measured compile time; both are currently unused.
Status TpuProgramGroup::LogCompilationStats(const TpuCompilationCacheKey& key,
                                            absl::Duration duration) {
  // A placeholder for tracking compilation statistics for future work. The
  // implementation can be pushing into some external storage for analytics.
  return Status::OK();
}
197 
may_modify_variables_list() const198 const std::vector<bool>& TpuProgramGroup::may_modify_variables_list() const {
199   return may_modify_variables_;
200 }
201 
set_may_modify_variables(const std::vector<bool> & may_modify_variables)202 void TpuProgramGroup::set_may_modify_variables(
203     const std::vector<bool>& may_modify_variables) {
204   may_modify_variables_ = may_modify_variables;
205 }
206 
may_modify_variables(int index) const207 bool TpuProgramGroup::may_modify_variables(int index) const {
208   CHECK_GE(index, 0);
209   CHECK_LT(index, tpu_programs_.size());
210   bool may_modify_variables;
211   OpsApiFn()->TpuProgram_GetMayModifyVariablesFn(tpu_programs_[index],
212                                                  &may_modify_variables);
213   return may_modify_variables;
214 }
215 
tpu_programs() const216 const std::vector<XLA_TpuProgram*>& TpuProgramGroup::tpu_programs() const {
217   return tpu_programs_;
218 }
219 
fingerprints() const220 const std::vector<std::string>& TpuProgramGroup::fingerprints() const {
221   return tpu_program_fingerprints_;
222 }
223 
set_fingerprints()224 void TpuProgramGroup::set_fingerprints() {
225   for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
226     TpuProgramFingerprint fingerprint =
227         OpsApiFn()->TpuProgram_GetFingerprintFn(tpu_program);
228     tpu_program_fingerprints_.emplace_back(
229         std::string(fingerprint.bytes, fingerprint.size));
230     OpsApiFn()->TpuProgram_DestroyFingerprintFn(fingerprint);
231   }
232 }
233 
fingerprint(int index) const234 const std::string& TpuProgramGroup::fingerprint(int index) const {
235   return fingerprints().at(index);
236 }
237 
tpu_program(int index) const238 const XLA_TpuProgram* TpuProgramGroup::tpu_program(int index) const {
239   CHECK_GE(index, 0);
240   CHECK_LT(index, tpu_programs_.size());
241   return tpu_programs_[index];
242 }
243 
set_tpu_programs(absl::Span<XLA_TpuProgram * const> tpu_programs)244 void TpuProgramGroup::set_tpu_programs(
245     absl::Span<XLA_TpuProgram* const> tpu_programs) {
246   tpu_programs_.resize(tpu_programs.size());
247   for (size_t i = 0; i < tpu_programs.size(); ++i) {
248     tpu_programs_[i] = tpu_programs[i];
249   }
250 }
251 
executable_info(int index) const252 const TPUExecutableInfoProto& TpuProgramGroup::executable_info(
253     int index) const {
254   CHECK_GE(index, 0);
255   CHECK_LT(index, executable_infos_.size());
256   return executable_infos_[index];
257 }
258 
host_transfer_info(int index) const259 const TPUHostTransferInfoProto& TpuProgramGroup::host_transfer_info(
260     int index) const {
261   CHECK_GE(index, 0);
262   CHECK_LT(index, host_transfer_infos_.size());
263   return host_transfer_infos_[index];
264 }
265 
266 /*static*/
CompileAndBuild(const TpuCompilationRequestProto & compilation_request,const XLA_TpuMeshState * mesh_state,TpuProgramGroupInterface * tpu_program_group_interface)267 Status TpuProgramGroup::CompileAndBuild(
268     const TpuCompilationRequestProto& compilation_request,
269     const XLA_TpuMeshState* mesh_state,
270     TpuProgramGroupInterface* tpu_program_group_interface) {
271   se_tpu::SerializedProto serialized_compilation_request =
272       se_tpu::SerializeProto(compilation_request);
273   auto cleanup = gtl::MakeCleanup([serialized_compilation_request] {
274     se_tpu::SerializedProto_Free(serialized_compilation_request);
275   });
276   size_t count = 0;
277   XLA_TpuProgram** xla_tpu_programs = nullptr;
278   StatusHelper status;
279   OpsApiFn()->TpuCompile_CompileAndBuildFn(serialized_compilation_request,
280                                            mesh_state, &xla_tpu_programs,
281                                            &count, status.c_status);
282   if (!status.ok()) {
283     VLOG(1) << "Run CompileAndBuild failed.";
284     return status.status();
285   }
286 
287   // SPMD could return 1 result for all partitions.
288   TF_RET_CHECK(count == 1 ||
289                count == compilation_request.metadata().num_cores_per_replica());
290 
291   VLOG(1) << "Initialize TpuProgramGroup.";
292   TpuProgramGroup* tpu_program_group =
293       tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
294   tpu_program_group->Initialize(
295       absl::MakeConstSpan(&xla_tpu_programs[0], count));
296   OpsApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs);
297   return status.status();
298 }
299 
300 /*static*/
CompileAndBuild(const xrt::XLAComputation & xrt_computation_proto,const XLA_TpuMeshState * mesh_state,TpuProgramGroupInterface * tpu_program_group_interface)301 Status TpuProgramGroup::CompileAndBuild(
302     const xrt::XLAComputation& xrt_computation_proto,
303     const XLA_TpuMeshState* mesh_state,
304     TpuProgramGroupInterface* tpu_program_group_interface) {
305   se_tpu::SerializedProto serialized_compilation_request =
306       se_tpu::SerializeProto(xrt_computation_proto);
307   auto cleanup = gtl::MakeCleanup([serialized_compilation_request] {
308     se_tpu::SerializedProto_Free(serialized_compilation_request);
309   });
310   size_t count = 0;
311   XLA_TpuProgram** xla_tpu_programs = nullptr;
312   StatusHelper status;
313   OpsApiFn()->TpuCompile_XrtCompileAndBuildFn(serialized_compilation_request,
314                                               mesh_state, &xla_tpu_programs,
315                                               &count, status.c_status);
316   if (!status.ok()) {
317     VLOG(1) << "Run CompileAndBuild failed.";
318     return status.status();
319   }
320 
321   // SPMD could return 1 result for all partitions.
322   int num_cores_per_replica =
323       xrt_computation_proto.config().num_cores_per_replica()
324           ? xrt_computation_proto.config().num_cores_per_replica()
325           : 1;
326   TF_RET_CHECK(count == 1 || count == num_cores_per_replica);
327   VLOG(1) << "Initialize TpuProgramGroup.";
328   TpuProgramGroup* tpu_program_group =
329       tensorflow::down_cast<TpuProgramGroup*>(tpu_program_group_interface);
330   tpu_program_group->Initialize(
331       absl::MakeConstSpan(&xla_tpu_programs[0], count));
332   OpsApiFn()->TpuProgram_FreeArrayFn(xla_tpu_programs);
333   return status.status();
334 }
335 
tpu_programs(TpuProgramShardingType sharding_type) const336 std::vector<XLA_TpuProgram*> TpuProgramGroup::tpu_programs(
337     TpuProgramShardingType sharding_type) const {
338   std::vector<XLA_TpuProgram*> tpu_programs;
339   tpu_programs.reserve(tpu_programs_.size());
340   for (size_t i = 0; i < tpu_programs_.size(); ++i) {
341     if (OpsApiFn()->TpuProgram_HasShardingFn(tpu_programs_[i])) {
342       tpu_programs.push_back(OpsApiFn()->TpuProgram_GetTpuProgramFn(
343           tpu_programs_[i], sharding_type));
344       CHECK_NE(tpu_programs[i], nullptr);
345     }
346   }
347   return tpu_programs;
348 }
349 
DeserializeFromRpcResponseProtos(const std::vector<TpuSerializedProto> & rpc_response_protos)350 Status TpuProgramGroup::DeserializeFromRpcResponseProtos(
351     const std::vector<TpuSerializedProto>& rpc_response_protos) {
352   std::vector<XLA_TpuProgram*> tpu_programs;
353   tpu_programs.resize(rpc_response_protos.size());
354 
355   for (size_t i = 0; i < rpc_response_protos.size(); ++i) {
356     StatusHelper status;
357     auto* xla_tpu_program = OpsApiFn()->TpuProgram_NewFn();
358     OpsApiFn()->TpuProgram_DeserializeFromGetTpuProgramResponseProtoFn(
359         rpc_response_protos[i], xla_tpu_program, status.c_status);
360     if (!status.status().ok()) {
361       OpsApiFn()->TpuProgram_FreeFn(xla_tpu_program);
362       return status.status();
363     }
364     tpu_programs[i] = xla_tpu_program;
365   }
366 
367   Initialize(tpu_programs);
368   return Status::OK();
369 }
370 
SerializeExecutable(int index,TpuExecutableSerializedProto * executable) const371 Status TpuProgramGroup::SerializeExecutable(
372     int index, TpuExecutableSerializedProto* executable) const {
373   CHECK_GE(index, 0);
374   CHECK_LT(index, tpu_programs_.size());
375   StatusHelper status;
376   OpsApiFn()->TpuProgram_SerializeTpuExecutableFn(tpu_programs_[index],
377                                                   executable, status.c_status);
378   return status.status();
379 }
380 
SerializeCompilerMetadata(int index,CompilerMetadataSerializedProto * compiler_metadata) const381 Status TpuProgramGroup::SerializeCompilerMetadata(
382     int index, CompilerMetadataSerializedProto* compiler_metadata) const {
383   CHECK_GE(index, 0);
384   CHECK_LT(index, tpu_programs_.size());
385   StatusHelper status;
386   OpsApiFn()->TpuProgram_SerializeCompilerMetadataFn(
387       tpu_programs_[index], compiler_metadata, status.c_status);
388   return status.status();
389 }
390 }  // namespace tpu
391 }  // namespace tensorflow
392