/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_COMMON_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_COMMON_H_

#include <memory>
#include <vector>

#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/jit/shape_inference.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

namespace tensorflow {
namespace tpu {
// Forward declaration, defined below.
class TpuCompileOpKernelCommon;

// A base factory class for creating a `TpuCompileOpKernelImpl` variant.
// By design, the actual factory can only be set once.
class CompileOpImplFactory {
 public:
  virtual ~CompileOpImplFactory() = default;

  virtual stream_executor::port::StatusOr<
      std::unique_ptr<TpuCompileOpKernelCommon>>
  CreateNonMlirImpl(OpKernelConstruction* ctx) = 0;

  virtual stream_executor::port::StatusOr<
      std::unique_ptr<TpuCompileOpKernelCommon>>
  CreateMlirImpl(OpKernelConstruction* ctx) = 0;

  static CompileOpImplFactory* Get();
  static void Register(CompileOpImplFactory* factory);

 private:
  static CompileOpImplFactory* factory_;
};
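
// Example (illustrative sketch, not part of the TensorFlow API surface):
// a hypothetical factory subclass that registers itself once at load time.
// `MyCompileOpImplFactory` and `MyTpuCompileOpKernelImpl` are assumed names
// used for illustration only.
//
//   class MyCompileOpImplFactory : public CompileOpImplFactory {
//    public:
//     stream_executor::port::StatusOr<
//         std::unique_ptr<TpuCompileOpKernelCommon>>
//     CreateNonMlirImpl(OpKernelConstruction* ctx) override {
//       std::unique_ptr<TpuCompileOpKernelCommon> kernel =
//           std::make_unique<MyTpuCompileOpKernelImpl>(/*...*/);
//       return kernel;
//     }
//     stream_executor::port::StatusOr<
//         std::unique_ptr<TpuCompileOpKernelCommon>>
//     CreateMlirImpl(OpKernelConstruction* ctx) override {
//       std::unique_ptr<TpuCompileOpKernelCommon> kernel =
//           std::make_unique<MyTpuCompileOpKernelImpl>(/*...*/);
//       return kernel;
//     }
//   };
//
//   // Registration happens exactly once, e.g. from a static initializer;
//   // later callers retrieve the factory via CompileOpImplFactory::Get().
//   static bool register_factory = [] {
//     static MyCompileOpImplFactory* factory = new MyCompileOpImplFactory();
//     CompileOpImplFactory::Register(factory);
//     return true;
//   }();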

// Abstract base class for TpuCompileOpKernel implementations.
class TpuCompileOpKernelCommon {
 public:
  TpuCompileOpKernelCommon(const std::string& mlir_module,
                           const tpu::TPUCompileMetadataProto metadata,
                           int num_computations, bool return_hlo_protos,
                           bool unload_cache_on_session_close)
      : metadata_(metadata),
        use_mlir_(true),
        mlir_module_(mlir_module),
        num_computations_(num_computations),
        return_hlo_protos_(return_hlo_protos),
        unload_cache_entry_on_session_close_(unload_cache_on_session_close),
        persistent_cache_(nullptr) {
    mlir_module_fingerprint_ = tensorflow::Fingerprint64(mlir_module_);
  }

  TpuCompileOpKernelCommon(
      const NameAttrList& function, const tpu::TPUCompileMetadataProto metadata,
      int num_computations, bool return_hlo_protos,
      bool unload_cache_on_session_close,
      std::unique_ptr<TpuPersistentCompilationCacheInterface> persistent_cache)
      : metadata_(metadata),
        use_mlir_(false),
        function_(function),
        num_computations_(num_computations),
        return_hlo_protos_(return_hlo_protos),
        unload_cache_entry_on_session_close_(unload_cache_on_session_close),
        persistent_cache_(std::move(persistent_cache)) {}

  virtual ~TpuCompileOpKernelCommon() = default;

  void Compute(OpKernelContext* ctx);

  // Lowers the MLIR module or TF function computation into HLO IR and, using
  // the XLA compiler, compiles it into TPU programs ready for execution.
  virtual Status Compile(
      const absl::variant<MlirToHloArgs, FunctionToHloArgs>& computation,
      const XLA_TpuMeshState* mesh_state,
      const std::vector<TensorShape>& arg_shapes,
      TpuProgramGroupInterface* tpu_program_group) = 0;
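
  // Example (illustrative sketch): a hypothetical concrete kernel overriding
  // `Compile`. The compilation logic itself is elided; only the shape of the
  // override is shown. `MyTpuCompileOpKernelImpl` is an assumed name.
  //
  //   class MyTpuCompileOpKernelImpl : public TpuCompileOpKernelCommon {
  //    public:
  //     using TpuCompileOpKernelCommon::TpuCompileOpKernelCommon;
  //     Status Compile(
  //         const absl::variant<MlirToHloArgs, FunctionToHloArgs>&
  //             computation,
  //         const XLA_TpuMeshState* mesh_state,
  //         const std::vector<TensorShape>& arg_shapes,
  //         TpuProgramGroupInterface* tpu_program_group) override {
  //       // Lower `computation` to HLO and compile one program per core,
  //       // populating `tpu_program_group`.
  //       // ...
  //       return Status::OK();
  //     }
  //   };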

  // Performs shape inference on `computation`, filling `shape_info` with
  // operator shapes. The shapes of the `_Arg` nodes are taken from
  // `arg_shapes`.
  static Status RunShapeInferenceOnComputation(
      const tpu::TPUCompileMetadataProto& metadata,
      const std::vector<PartialTensorShape>& arg_shapes, Graph* graph,
      FunctionLibraryRuntime* flr, GraphShapeInfo* shape_info);
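
  // Example call (illustrative; `metadata`, `arg_shapes`, a `Graph* graph`,
  // and a `FunctionLibraryRuntime* flr` are assumed to already be in scope):
  //
  //   GraphShapeInfo shape_info;
  //   TF_RETURN_IF_ERROR(
  //       TpuCompileOpKernelCommon::RunShapeInferenceOnComputation(
  //           metadata, arg_shapes, graph, flr, &shape_info));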

 protected:
  Status ComputeInternal(OpKernelContext* ctx);

  // Compiles the TPU program locally and populates the host compilation
  // cache.
  Status CompileLocallyAndFillHostCache(
      FunctionLibraryRuntime* flib_runtime,
      const SessionMetadata* session_metadata,
      const TpuMeshStateInterface* mesh_state,
      const std::vector<TensorShape>& dynamic_shapes,
      const OpInputList& guaranteed_constants,
      const tpu::TpuCompilationCacheKey& key,
      TpuProgramGroupInterface* tpu_program_group);

  // Looks up the persistent compilation cache and populates both the host
  // cache and the persistent cache.
  virtual Status LookupPersistentCompilationCacheAndFillCaches(
      FunctionLibraryRuntime* flib_runtime,
      const SessionMetadata* session_metadata,
      const TpuMeshStateInterface* mesh_state,
      const std::vector<TensorShape>& dynamic_shapes,
      const OpInputList& guaranteed_constants,
      TpuPersistentCompilationCacheInterface* persistent_cache,
      const tpu::TpuCompilationCacheKey& key,
      TpuProgramGroupInterface* tpu_program_group) {
    LOG(FATAL) << "Lookup from a persistent cache is NOT supported.";
  }

  // Sleeps for `kSleepSeconds` seconds to give TPUCompileOp time to finish
  // before terminating peacefully.
  static void ExitCountdown(Env* env, std::shared_ptr<std::atomic<bool>> done);

  // Converts the `dynamic_shapes` arguments to the compile operator into
  // TensorShapes.
  static Status GetDynamicShapes(OpKernelContext* ctx,
                                 std::vector<TensorShape>* shapes);

  // Adds TPU_REPLICATED_CORE device assignments to the _Arg and _Retval
  // nodes in `graph`, using the sharding/index assignments in
  // `arg_core_mapping` and `retval_core_mapping`. The mappings map each
  // original argument/return index to a (sharding, per-core argument/return
  // index) pair. Node attributes, such as device assignments, are not
  // preserved on function argument and return value nodes, so we must
  // recreate them from the compilation metadata.
  static Status AssignDevicesToArgsAndRetvals(
      absl::Span<const tpu::ShardingAndIndex> arg_core_mapping,
      absl::Span<const tpu::ShardingAndIndex> retval_core_mapping,
      Graph* graph);

  // Optimizes `graph`, given the argument descriptions in `metadata` and
  // `arg_shapes`.
  static Status OptimizeGraph(const tpu::TPUCompileMetadataProto& metadata,
                              const std::vector<PartialTensorShape>& arg_shapes,
                              std::unique_ptr<Graph>* graph,
                              FunctionLibraryRuntime* flr,
                              FunctionLibraryDefinition* fld);

  // Converts a TF function into XLA HLO, storing the generated HLO module
  // and accompanying metadata in `compilation_result`.
  Status CompileTFFunctionToHlo(
      const FunctionLibraryDefinition& flib_def, int graph_def_version,
      const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
      const std::vector<TensorShape>& arg_shapes,
      const GuaranteedConsts& guaranteed_constants,
      const NameAttrList& function,
      std::function<Status(ResourceMgr*)> populate_resource_manager_fn,
      xla::CompileOnlyClient* client,
      std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
      std::vector<std::vector<xla::Shape>>* per_core_arg_shapes,
      XlaCompiler::CompilationResult* compilation_result);

  // Gets information about how the input arguments are sharded across
  // multiple cores.
  Status GetShardingInfo(
      absl::Span<const TensorShape> arg_shapes,
      const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
      std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
      std::vector<std::vector<xla::Shape>>* per_core_arg_shapes);

  // Populates the mapping from return value to ShardingAndIndex.
  Status AssignReturnValueToCore(
      std::vector<tpu::ShardingAndIndex>* retval_core_mapping);

  // Populates the arguments, core mapping, and per-core argument shapes for
  // the computation.
  Status BuildComputationArgumentDescriptions(
      const std::vector<TensorShape>& arg_shapes,
      const GuaranteedConsts& guaranteed_constants, const XlaCompiler& compiler,
      std::vector<XlaCompiler::Argument>* args,
      std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
      std::vector<std::vector<xla::Shape>>* per_core_arg_shapes);
  const tpu::TPUCompileMetadataProto metadata_;

  // Whether to compile the MLIR module in `mlir_module_` instead of the
  // TensorFlow function referenced by `function_`.
  bool use_mlir_;

  // Function containing the computation to compile.
  NameAttrList function_;

  // A serialized MLIR ModuleOp.
  std::string mlir_module_;
  // Fingerprint of the MLIR module, computed once at construction to avoid
  // paying the cost on each invocation.
  uint64 mlir_module_fingerprint_ = 0;

  // Number of different programs to compile. This corresponds to the number
  // of cores in each replica.
  int num_computations_;

  // A flag to populate the HLO protos field in CompilationResultProto. The
  // HLO metadata can be large, so it is not populated by default unless
  // explicitly requested.
  bool return_hlo_protos_;

  // If enabled, DirectSession::Close will unload cache entries created during
  // the lifetime of the session.
  bool unload_cache_entry_on_session_close_;

  // Persistent cache of compiled TPU programs for inference.
  std::unique_ptr<TpuPersistentCompilationCacheInterface> persistent_cache_;

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(TpuCompileOpKernelCommon);
};
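
// Example (illustrative): retrieving the registered factory and running the
// kernel implementation. `ctx` (an OpKernelConstruction*) and `op_ctx` (an
// OpKernelContext*) are assumed to be provided by the TensorFlow runtime,
// and the choice of the MLIR variant here is an assumption for illustration.
//
//   auto impl = CompileOpImplFactory::Get()->CreateMlirImpl(ctx);
//   if (impl.ok()) {
//     impl.ValueOrDie()->Compute(op_ctx);
//   }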
}  // namespace tpu
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_COMMON_H_