• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h"
16 
17 #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
18 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
19 #include "tensorflow/core/platform/casts.h"
20 #if defined(LIBTPU_ON_GCE)
21 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
22 #endif
23 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
24 #include "tensorflow/core/tpu/kernels/tpu_program_group.h"
25 #include "tensorflow/stream_executor/tpu/proto_helper.h"
26 
27 namespace tensorflow {
28 namespace tpu {
CreateChannelCredentials()29 std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials() {
30   return ::grpc::InsecureChannelCredentials();  // NOLINT
31 }
32 
33 #if defined(LIBTPU_ON_GCE)
34 template <>
DeserializeRpcResponseToCacheEntry(absl::string_view local_proto_key,GetTpuProgramResponseExternal * response,std::shared_ptr<CacheEntry> * cache_entry)35 Status DeserializeRpcResponseToCacheEntry<GetTpuProgramResponseExternal>(
36     absl::string_view local_proto_key, GetTpuProgramResponseExternal* response,
37     std::shared_ptr<CacheEntry>* cache_entry) {
38   CHECK_NE(response, nullptr);
39   CHECK_NE(cache_entry, nullptr);
40   *cache_entry = std::make_shared<CacheEntry>();
41   CacheEntry& entry = **cache_entry;
42   entry.key = std::string(local_proto_key);
43 
44   if (response->is_empty()) {
45     entry.size = 0;
46   } else {
47     TpuSerializedProto serialized_response_proto =
48         stream_executor::tpu::SerializeProto(*response);
49     auto cleanup = xla::MakeCleanup([&serialized_response_proto]() {
50       stream_executor::tpu::SerializedProto_Free(serialized_response_proto);
51     });
52     // When we lookup from remote cache, we fetch a TPU program for a specific
53     // core, hence we allocate TPU program group for a single program.
54     auto tpu_program_group = absl::make_unique<TpuProgramGroup>();
55 
56     // TODO(b/166575150): can be optimized by sending the buffer over the gRPC
57     // without an extra deserializing.
58     TF_RETURN_IF_ERROR(tpu_program_group->DeserializeFromRpcResponseProtos(
59         {serialized_response_proto}));
60     entry.tpu_program_group = std::move(tpu_program_group);
61     entry.size = entry.tpu_program_group->program_size();
62   }
63 
64   return Status::OK();
65 }
66 
SerializeCacheEntryToBufferSlices(const TpuCompilationCacheEntry & cache_entry)67 xla::StatusOr<std::vector<::grpc::Slice>> SerializeCacheEntryToBufferSlices(
68     const TpuCompilationCacheEntry& cache_entry) {
69   if (cache_entry.tpu_program_group() == nullptr) {
70     // It's possible that the sharding/unsharding entry does not exist, but the
71     // main entry must exist.
72     GetTpuProgramResponseExternal header;
73     header.set_is_empty(true);
74     std::string encoded_header;
75     if (!header.AppendToString(&encoded_header)) {
76       return errors::Internal("Failed to serialize TPU program metadata.");
77     }
78     ::grpc::Slice slice(encoded_header);
79     return std::vector<::grpc::Slice>{slice};
80   }
81 
82   const TpuProgramGroup* tpu_program_group =
83       tensorflow::down_cast<const TpuProgramGroup*>(
84           cache_entry.tpu_program_group());
85   CHECK_NE(tpu_program_group, nullptr);
86   CHECK_GE(tpu_program_group->program_count(), 0);
87   CHECK_GE(cache_entry.core_index(), 0);
88   CHECK_LT(cache_entry.core_index(), tpu_program_group->program_count());
89   const int64 program_size = tpu_program_group->program_size();
90   if (program_size > INT_MAX) {
91     return errors::Internal("TPU program exceeded 2 GiB.");
92   }
93 
94   TpuExecutableSerializedProto executable;
95   auto cleanup_executable = xla::MakeCleanup([&executable]() {
96     if (executable.size > 0) {
97       stream_executor::tpu::SerializedProto_Free(executable);
98     }
99   });
100   auto get_executable_status = tpu_program_group->SerializeExecutable(
101       cache_entry.core_index(), &executable);
102   if (!get_executable_status.ok()) {
103     return errors::Internal("Failed to serialize TPU program.");
104   }
105 
106   // Encode and serialize header fields.
107   GetTpuProgramResponseExternal header;
108   if (!header.mutable_proto()->ParseFromArray(executable.bytes,
109                                               executable.size)) {
110     return errors::Internal("Failed to serialize TPU program.");
111   }
112   header.set_is_empty(false);
113 
114 
115   bool may_modify_variables =
116       tpu_program_group->may_modify_variables(cache_entry.core_index());
117   header.set_may_modify_variables(may_modify_variables);
118 
119   CompilerMetadataSerializedProto compiler_metadata;
120   auto cleanup_compiler_metadata = xla::MakeCleanup([&compiler_metadata]() {
121     if (compiler_metadata.size > 0) {
122       stream_executor::tpu::SerializedProto_Free(compiler_metadata);
123     }
124   });
125   Status get_compiler_metadata_status =
126       tpu_program_group->SerializeCompilerMetadata(cache_entry.core_index(),
127                                                    &compiler_metadata);
128   if (!get_compiler_metadata_status.ok()) {
129     return errors::Internal("Failed to serialize compiler metadata.");
130   }
131   if (!header.mutable_compiler_metadata()->ParseFromArray(
132           compiler_metadata.bytes, compiler_metadata.size)) {
133     return errors::Internal("Failed to deserialize compiler metadata.");
134   }
135   std::string encoded_header;
136   if (!header.AppendToString(&encoded_header)) {
137     return errors::Internal("Failed to serialize TPU program metadata.");
138   }
139 
140   return std::vector<::grpc::Slice>{::grpc::Slice(encoded_header)};
141 }
142 #endif  // LIBTPU_ON_GCE
143 }  // namespace tpu
144 }  // namespace tensorflow
145