• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_SUPPORT_H_
16 #define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_SUPPORT_H_
17 
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <vector>
22 
23 #include "grpcpp/security/credentials.h"
24 #include "grpcpp/support/slice.h"
25 #include "absl/strings/string_view.h"
26 #include "tensorflow/core/platform/status.h"
27 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
28 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
29 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
30 #include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
31 
32 namespace tensorflow {
33 namespace tpu {
34 
35 // A cache entry for remote TPU compilation.
36 struct CacheEntry {
CacheEntryCacheEntry37   CacheEntry() : size(0), last_use(-1) {}
~CacheEntryCacheEntry38   virtual ~CacheEntry() {
39     if (tpu_program_group != nullptr) {
40       tpu_program_group->UnloadAndDestroyPrograms();
41     }
42   }
43   std::unique_ptr<TpuProgramGroupInterface> tpu_program_group;
44   std::string key;
45   int64 size;
46 
47   // An integer-based monotonically increasing counter used by the TPU
48   // compilation cache to sort and evict the least recently used entry when the
49   // cache size exceeded the maximum size limit. The value is initialized to
50   // `-1` as an initial value.
51   int64 last_use;
52 };
53 
54 // Implementation of `CompilationCacheEntryRef` that holds a shared_ptr to the
55 // local cache entry until the wrapper is destroyed.
56 class CacheWrapper : public CompilationCacheEntryRef {
57  public:
CacheWrapper(std::shared_ptr<CacheEntry> entry)58   explicit CacheWrapper(std::shared_ptr<CacheEntry> entry)
59       : cache_entry_(std::move(entry)) {}
60   ~CacheWrapper() override = default;
61 
get()62   TpuCompilationCacheEntry get() override {
63     if (cache_entry_->size == 0) {
64       // Create an empty entry if the size is 0. This corresponds to
65       // non-existing sharding/unsharding entries.
66       return TpuCompilationCacheEntry();
67     }
68     return TpuCompilationCacheEntry(cache_entry_->tpu_program_group.get(),
69                                     /*core_index=*/0);
70   }
71 
ToSubEntryRef(CompilationCacheFetchTarget fetch_target)72   Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target) override {
73     LOG(FATAL) << "Not implemented by designed.";
74   }
75 
76  private:
77   std::shared_ptr<CacheEntry> cache_entry_;
78 };
79 
80 // Creates gRPC channel credentials for the current runtime env.
81 std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials();
82 
83 // Fills an uninitialized `CacheEntry` from `GetTpuProgramResponse` proto. The
84 // `cache_entry` will be instantiated by the function.
85 template <typename ResponseType>
86 Status DeserializeRpcResponseToCacheEntry(
87     const absl::string_view local_proto_key, ResponseType* response,
88     std::shared_ptr<CacheEntry>* cache_entry);
89 
90 // Serializes `TpuCompilationCacheEntry` to gRPC buffer slices.
91 xla::StatusOr<std::vector<::grpc::Slice>> SerializeCacheEntryToBufferSlices(
92     const TpuCompilationCacheEntry& cache_entry);
93 }  // namespace tpu
94 }  // namespace tensorflow
95 
96 #endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_SUPPORT_H_
97