1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h"
16
17 #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
18 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
19 #include "tensorflow/core/platform/casts.h"
20 #if defined(LIBTPU_ON_GCE)
21 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
22 #endif
23 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
24 #include "tensorflow/core/tpu/kernels/tpu_program_group.h"
25 #include "tensorflow/stream_executor/tpu/proto_helper.h"
26
27 namespace tensorflow {
28 namespace tpu {
CreateChannelCredentials()29 std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials() {
30 return ::grpc::InsecureChannelCredentials(); // NOLINT
31 }
32
33 #if defined(LIBTPU_ON_GCE)
34 template <>
DeserializeRpcResponseToCacheEntry(absl::string_view local_proto_key,GetTpuProgramResponseExternal * response,std::shared_ptr<CacheEntry> * cache_entry)35 Status DeserializeRpcResponseToCacheEntry<GetTpuProgramResponseExternal>(
36 absl::string_view local_proto_key, GetTpuProgramResponseExternal* response,
37 std::shared_ptr<CacheEntry>* cache_entry) {
38 CHECK_NE(response, nullptr);
39 CHECK_NE(cache_entry, nullptr);
40 *cache_entry = std::make_shared<CacheEntry>();
41 CacheEntry& entry = **cache_entry;
42 entry.key = std::string(local_proto_key);
43
44 if (response->is_empty()) {
45 entry.size = 0;
46 } else {
47 TpuSerializedProto serialized_response_proto =
48 stream_executor::tpu::SerializeProto(*response);
49 auto cleanup = xla::MakeCleanup([&serialized_response_proto]() {
50 stream_executor::tpu::SerializedProto_Free(serialized_response_proto);
51 });
52 // When we lookup from remote cache, we fetch a TPU program for a specific
53 // core, hence we allocate TPU program group for a single program.
54 auto tpu_program_group = absl::make_unique<TpuProgramGroup>();
55
56 // TODO(b/166575150): can be optimized by sending the buffer over the gRPC
57 // without an extra deserializing.
58 TF_RETURN_IF_ERROR(tpu_program_group->DeserializeFromRpcResponseProtos(
59 {serialized_response_proto}));
60 entry.tpu_program_group = std::move(tpu_program_group);
61 entry.size = entry.tpu_program_group->program_size();
62 }
63
64 return Status::OK();
65 }
66
SerializeCacheEntryToBufferSlices(const TpuCompilationCacheEntry & cache_entry)67 xla::StatusOr<std::vector<::grpc::Slice>> SerializeCacheEntryToBufferSlices(
68 const TpuCompilationCacheEntry& cache_entry) {
69 if (cache_entry.tpu_program_group() == nullptr) {
70 // It's possible that the sharding/unsharding entry does not exist, but the
71 // main entry must exist.
72 GetTpuProgramResponseExternal header;
73 header.set_is_empty(true);
74 std::string encoded_header;
75 if (!header.AppendToString(&encoded_header)) {
76 return errors::Internal("Failed to serialize TPU program metadata.");
77 }
78 ::grpc::Slice slice(encoded_header);
79 return std::vector<::grpc::Slice>{slice};
80 }
81
82 const TpuProgramGroup* tpu_program_group =
83 tensorflow::down_cast<const TpuProgramGroup*>(
84 cache_entry.tpu_program_group());
85 CHECK_NE(tpu_program_group, nullptr);
86 CHECK_GE(tpu_program_group->program_count(), 0);
87 CHECK_GE(cache_entry.core_index(), 0);
88 CHECK_LT(cache_entry.core_index(), tpu_program_group->program_count());
89 const int64 program_size = tpu_program_group->program_size();
90 if (program_size > INT_MAX) {
91 return errors::Internal("TPU program exceeded 2 GiB.");
92 }
93
94 TpuExecutableSerializedProto executable;
95 auto cleanup_executable = xla::MakeCleanup([&executable]() {
96 if (executable.size > 0) {
97 stream_executor::tpu::SerializedProto_Free(executable);
98 }
99 });
100 auto get_executable_status = tpu_program_group->SerializeExecutable(
101 cache_entry.core_index(), &executable);
102 if (!get_executable_status.ok()) {
103 return errors::Internal("Failed to serialize TPU program.");
104 }
105
106 // Encode and serialize header fields.
107 GetTpuProgramResponseExternal header;
108 if (!header.mutable_proto()->ParseFromArray(executable.bytes,
109 executable.size)) {
110 return errors::Internal("Failed to serialize TPU program.");
111 }
112 header.set_is_empty(false);
113
114
115 bool may_modify_variables =
116 tpu_program_group->may_modify_variables(cache_entry.core_index());
117 header.set_may_modify_variables(may_modify_variables);
118
119 CompilerMetadataSerializedProto compiler_metadata;
120 auto cleanup_compiler_metadata = xla::MakeCleanup([&compiler_metadata]() {
121 if (compiler_metadata.size > 0) {
122 stream_executor::tpu::SerializedProto_Free(compiler_metadata);
123 }
124 });
125 Status get_compiler_metadata_status =
126 tpu_program_group->SerializeCompilerMetadata(cache_entry.core_index(),
127 &compiler_metadata);
128 if (!get_compiler_metadata_status.ok()) {
129 return errors::Internal("Failed to serialize compiler metadata.");
130 }
131 if (!header.mutable_compiler_metadata()->ParseFromArray(
132 compiler_metadata.bytes, compiler_metadata.size)) {
133 return errors::Internal("Failed to deserialize compiler metadata.");
134 }
135 std::string encoded_header;
136 if (!header.AppendToString(&encoded_header)) {
137 return errors::Internal("Failed to serialize TPU program metadata.");
138 }
139
140 return std::vector<::grpc::Slice>{::grpc::Slice(encoded_header)};
141 }
142 #endif // LIBTPU_ON_GCE
143 } // namespace tpu
144 } // namespace tensorflow
145