1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_SUPPORT_H_ 16 #define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_SUPPORT_H_ 17 18 #include <functional> 19 #include <memory> 20 #include <string> 21 #include <vector> 22 23 #include "grpcpp/security/credentials.h" 24 #include "grpcpp/support/slice.h" 25 #include "absl/strings/string_view.h" 26 #include "tensorflow/core/platform/status.h" 27 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" 28 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" 29 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h" 30 #include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h" 31 32 namespace tensorflow { 33 namespace tpu { 34 35 // A cache entry for remote TPU compilation. 36 struct CacheEntry { CacheEntryCacheEntry37 CacheEntry() : size(0), last_use(-1) {} ~CacheEntryCacheEntry38 virtual ~CacheEntry() { 39 if (tpu_program_group != nullptr) { 40 tpu_program_group->UnloadAndDestroyPrograms(); 41 } 42 } 43 std::unique_ptr<TpuProgramGroupInterface> tpu_program_group; 44 std::string key; 45 int64 size; 46 47 // An integer-based monotonically increasing counter used by the TPU 48 // compilation cache to sort and evict the least recently used entry when the 49 // cache size exceeded the maximum size limit. The value is initialized to 50 // `-1` as an initial value. 51 int64 last_use; 52 }; 53 54 // Implementation of `CompilationCacheEntryRef` that holds a shared_ptr to the 55 // local cache entry until the wrapper is destroyed. 56 class CacheWrapper : public CompilationCacheEntryRef { 57 public: CacheWrapper(std::shared_ptr<CacheEntry> entry)58 explicit CacheWrapper(std::shared_ptr<CacheEntry> entry) 59 : cache_entry_(std::move(entry)) {} 60 ~CacheWrapper() override = default; 61 get()62 TpuCompilationCacheEntry get() override { 63 if (cache_entry_->size == 0) { 64 // Create an empty entry if the size is 0. This corresponds to 65 // non-existing sharding/unsharding entries. 66 return TpuCompilationCacheEntry(); 67 } 68 return TpuCompilationCacheEntry(cache_entry_->tpu_program_group.get(), 69 /*core_index=*/0); 70 } 71 ToSubEntryRef(CompilationCacheFetchTarget fetch_target)72 Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target) override { 73 LOG(FATAL) << "Not implemented by designed."; 74 } 75 76 private: 77 std::shared_ptr<CacheEntry> cache_entry_; 78 }; 79 80 // Creates gRPC channel credentials for the current runtime env. 81 std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials(); 82 83 // Fills an uinitialized `CacheEntry` from `GetTpuProgramResponse` proto. The 84 // `cache_entry` will be instantiated by the function. 85 template <typename ResponseType> 86 Status DeserializeRpcResponseToCacheEntry( 87 const absl::string_view local_proto_key, ResponseType* response, 88 std::shared_ptr<CacheEntry>* cache_entry); 89 90 // Serializes `TpuCompilationCacheEntry` to gRPC bufer slices. 91 xla::StatusOr<std::vector<::grpc::Slice>> SerializeCacheEntryToBufferSlices( 92 const TpuCompilationCacheEntry& cache_entry); 93 } // namespace tpu 94 } // namespace tensorflow 95 96 #endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_SUPPORT_H_ 97