/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_COMMON_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_COMMON_H_

#include <atomic>
#include <memory>
#include <string>
#include <vector>

#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/jit/shape_inference.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

namespace tensorflow {
namespace tpu {
// Forward declaration, defined below.
class TpuCompileOpKernelCommon;

42 // A base factory class for creating a `TpuCompileOpKernelImpl` variant.
43 // By design, the actual factory can only be set once.
44 class CompileOpImplFactory {
45  public:
46   virtual ~CompileOpImplFactory() = default;
47 
48   virtual stream_executor::port::StatusOr<
49       std::unique_ptr<TpuCompileOpKernelCommon>>
50   CreateNonMlirImpl(OpKernelConstruction* ctx) = 0;
51 
52   virtual stream_executor::port::StatusOr<
53       std::unique_ptr<TpuCompileOpKernelCommon>>
54   CreateMlirImpl(OpKernelConstruction* ctx) = 0;
55 
56   static CompileOpImplFactory* Get();
57   static void Register(CompileOpImplFactory* factory);
58 
59  private:
60   static CompileOpImplFactory* factory_;
61 };
62 
63 // Abstract base class for TpuCompileOpKernel implementation.
64 class TpuCompileOpKernelCommon {
65  public:
TpuCompileOpKernelCommon(const std::string & mlir_module,const tpu::TPUCompileMetadataProto metadata,int num_computations,bool return_hlo_protos,bool unload_cache_on_session_close)66   TpuCompileOpKernelCommon(const std::string& mlir_module,
67                            const tpu::TPUCompileMetadataProto metadata,
68                            int num_computations, bool return_hlo_protos,
69                            bool unload_cache_on_session_close)
70       : metadata_(metadata),
71         use_mlir_(true),
72         mlir_module_(mlir_module),
73         num_computations_(num_computations),
74         return_hlo_protos_(return_hlo_protos),
75         unload_cache_entry_on_session_close_(unload_cache_on_session_close),
76         persistent_cache_(nullptr) {
77     mlir_module_fingerprint_ = tensorflow::Fingerprint64(mlir_module_);
78   }
79 
TpuCompileOpKernelCommon(const NameAttrList & function,const tpu::TPUCompileMetadataProto metadata,int num_computations,bool return_hlo_protos,bool unload_cache_on_session_close,std::unique_ptr<TpuPersistentCompilationCacheInterface> persistent_cache)80   TpuCompileOpKernelCommon(
81       const NameAttrList& function, const tpu::TPUCompileMetadataProto metadata,
82       int num_computations, bool return_hlo_protos,
83       bool unload_cache_on_session_close,
84       std::unique_ptr<TpuPersistentCompilationCacheInterface> persistent_cache)
85       : metadata_(metadata),
86         use_mlir_(false),
87         function_(function),
88         num_computations_(num_computations),
89         return_hlo_protos_(return_hlo_protos),
90         unload_cache_entry_on_session_close_(unload_cache_on_session_close),
91         persistent_cache_(std::move(persistent_cache)) {}
92 
93   virtual ~TpuCompileOpKernelCommon() = default;
94 
95   void Compute(OpKernelContext* ctx);
96 
97   // Lowers Mlir or TF Function computation into HLO IR and using XLA compiler
98   // compiles into TPU programs ready for execution.
99   virtual Status Compile(
100       const absl::variant<MlirToHloArgs, FunctionToHloArgs>& computation,
101       const XLA_TpuMeshState* mesh_state,
102       const std::vector<TensorShape>& arg_shapes,
103       TpuProgramGroupInterface* tpu_program_group) = 0;
104 
105   // Performs shape inference on `computation`, filling shape_info with operator
106   // shapes. The shapes of the _Arg nodes are taken from `arg_shapes`.
107   static Status RunShapeInferenceOnComputation(
108       const tpu::TPUCompileMetadataProto& metadata,
109       const std::vector<PartialTensorShape>& arg_shapes, Graph* graph,
110       FunctionLibraryRuntime* flr, GraphShapeInfo* shape_info);
111 
112  protected:
113   Status ComputeInternal(OpKernelContext* ctx);
114 
115   // Compile TPU program locally and populate the host compilation cache.
116   Status CompileLocallyAndFillHostCache(
117       FunctionLibraryRuntime* flib_runtime,
118       const SessionMetadata* session_metadata,
119       const TpuMeshStateInterface* mesh_state,
120       const std::vector<TensorShape>& dynamic_shapes,
121       const OpInputList& guaranteed_constants,
122       const tpu::TpuCompilationCacheKey& key,
123       TpuProgramGroupInterface* tpu_program_group);
124 
125   // Lookup from persistent compilation cache and populate both host cache and
126   // persistent cache.
LookupPersistentCompilationCacheAndFillCaches(FunctionLibraryRuntime * flib_runtime,const SessionMetadata * session_metadata,const TpuMeshStateInterface * mesh_state,const std::vector<TensorShape> & dynamic_shapes,const OpInputList & guaranteed_constants,TpuPersistentCompilationCacheInterface * persistent_cache,const tpu::TpuCompilationCacheKey & key,TpuProgramGroupInterface * tpu_program_group)127   virtual Status LookupPersistentCompilationCacheAndFillCaches(
128       FunctionLibraryRuntime* flib_runtime,
129       const SessionMetadata* session_metadata,
130       const TpuMeshStateInterface* mesh_state,
131       const std::vector<TensorShape>& dynamic_shapes,
132       const OpInputList& guaranteed_constants,
133       TpuPersistentCompilationCacheInterface* persistent_cache,
134       const tpu::TpuCompilationCacheKey& key,
135       TpuProgramGroupInterface* tpu_program_group) {
136     LOG(FATAL) << "Lookup from a persistent cache is NOT supported.";
137   }
138 
139   // Sleeps for `kSleepSeconds` seconds to give time for TPUCompileOp to finish
140   // before terminating peacefully.
141   static void ExitCountdown(Env* env, std::shared_ptr<std::atomic<bool>> done);
142 
143   // Converts the `dynamic_shapes` arguments to the compile operator into
144   // TensorShapes.
145   static Status GetDynamicShapes(OpKernelContext* ctx,
146                                  std::vector<TensorShape>* shapes);
147 
148   tpu::TPUCompileMetadataProto metadata_;
149 
150   // Whether to compile given MLIR module in `mlir_module` instead of
151   // TensorFlow function referenced in `function_`.
152   bool use_mlir_;
153 
154   // Function containing the computation to compile.
155   NameAttrList function_;
156 
157   // A serialized MLIR ModuleOp.
158   std::string mlir_module_;
159   // Fingerprint of the MLIR Module created once on construction to avoid paying
160   // the cost on each invocation.
161   uint64 mlir_module_fingerprint_ = 0;
162 
163   // Number of different programs to compile. This maps to number of cores in
164   // each replica.
165   int num_computations_;
166 
167   // A flag to populate HLO protos field in CompilationResultProto. The HLO
168   // metadata could be large so default to not populating it unless explicitly
169   // requested.
170   bool return_hlo_protos_;
171 
172   // If enabled, DirectSession::Close will unload cache entries created during
173   // the lifetime of the session.
174   bool unload_cache_entry_on_session_close_;
175 
176   // Persistent cache for compiled TPU program for inference.
177   std::unique_ptr<TpuPersistentCompilationCacheInterface> persistent_cache_;
178 
179   Status RegisterXLAFingerprints(const std::vector<TensorShape>& arg_shapes,
180                                  TpuProgramGroupInterface* tpu_program_group,
181                                  uint64 fingerprint);
182 
183  private:
184   TF_DISALLOW_COPY_AND_ASSIGN(TpuCompileOpKernelCommon);
185 };
}  // namespace tpu
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_COMMON_H_