/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_COMMON_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_COMMON_H_

#include <atomic>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/jit/shape_inference.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

namespace tensorflow {
namespace tpu {
// Forward declaration, defined below.
class TpuCompileOpKernelCommon;

// A base factory class for creating a `TpuCompileOpKernelImpl` variant.
// By design, the actual factory can only be set once.
44 class CompileOpImplFactory { 45 public: 46 virtual ~CompileOpImplFactory() = default; 47 48 virtual stream_executor::port::StatusOr< 49 std::unique_ptr<TpuCompileOpKernelCommon>> 50 CreateNonMlirImpl(OpKernelConstruction* ctx) = 0; 51 52 virtual stream_executor::port::StatusOr< 53 std::unique_ptr<TpuCompileOpKernelCommon>> 54 CreateMlirImpl(OpKernelConstruction* ctx) = 0; 55 56 static CompileOpImplFactory* Get(); 57 static void Register(CompileOpImplFactory* factory); 58 59 private: 60 static CompileOpImplFactory* factory_; 61 }; 62 63 // Abstract base class for TpuCompileOpKernel implementation. 64 class TpuCompileOpKernelCommon { 65 public: TpuCompileOpKernelCommon(const std::string & mlir_module,const tpu::TPUCompileMetadataProto metadata,int num_computations,bool return_hlo_protos,bool unload_cache_on_session_close)66 TpuCompileOpKernelCommon(const std::string& mlir_module, 67 const tpu::TPUCompileMetadataProto metadata, 68 int num_computations, bool return_hlo_protos, 69 bool unload_cache_on_session_close) 70 : metadata_(metadata), 71 use_mlir_(true), 72 mlir_module_(mlir_module), 73 num_computations_(num_computations), 74 return_hlo_protos_(return_hlo_protos), 75 unload_cache_entry_on_session_close_(unload_cache_on_session_close), 76 persistent_cache_(nullptr) { 77 mlir_module_fingerprint_ = tensorflow::Fingerprint64(mlir_module_); 78 } 79 TpuCompileOpKernelCommon(const NameAttrList & function,const tpu::TPUCompileMetadataProto metadata,int num_computations,bool return_hlo_protos,bool unload_cache_on_session_close,std::unique_ptr<TpuPersistentCompilationCacheInterface> persistent_cache)80 TpuCompileOpKernelCommon( 81 const NameAttrList& function, const tpu::TPUCompileMetadataProto metadata, 82 int num_computations, bool return_hlo_protos, 83 bool unload_cache_on_session_close, 84 std::unique_ptr<TpuPersistentCompilationCacheInterface> persistent_cache) 85 : metadata_(metadata), 86 use_mlir_(false), 87 function_(function), 88 
num_computations_(num_computations), 89 return_hlo_protos_(return_hlo_protos), 90 unload_cache_entry_on_session_close_(unload_cache_on_session_close), 91 persistent_cache_(std::move(persistent_cache)) {} 92 93 virtual ~TpuCompileOpKernelCommon() = default; 94 95 void Compute(OpKernelContext* ctx); 96 97 // Lowers Mlir or TF Function computation into HLO IR and using XLA compiler 98 // compiles into TPU programs ready for execution. 99 virtual Status Compile( 100 const absl::variant<MlirToHloArgs, FunctionToHloArgs>& computation, 101 const XLA_TpuMeshState* mesh_state, 102 const std::vector<TensorShape>& arg_shapes, 103 TpuProgramGroupInterface* tpu_program_group) = 0; 104 105 // Performs shape inference on `computation`, filling shape_info with operator 106 // shapes. The shapes of the _Arg nodes are taken from `arg_shapes`. 107 static Status RunShapeInferenceOnComputation( 108 const tpu::TPUCompileMetadataProto& metadata, 109 const std::vector<PartialTensorShape>& arg_shapes, Graph* graph, 110 FunctionLibraryRuntime* flr, GraphShapeInfo* shape_info); 111 112 protected: 113 Status ComputeInternal(OpKernelContext* ctx); 114 115 // Compile TPU program locally and populate the host compilation cache. 116 Status CompileLocallyAndFillHostCache( 117 FunctionLibraryRuntime* flib_runtime, 118 const SessionMetadata* session_metadata, 119 const TpuMeshStateInterface* mesh_state, 120 const std::vector<TensorShape>& dynamic_shapes, 121 const OpInputList& guaranteed_constants, 122 const tpu::TpuCompilationCacheKey& key, 123 TpuProgramGroupInterface* tpu_program_group); 124 125 // Lookup from persistent compilation cache and populate both host cache and 126 // persistent cache. 
LookupPersistentCompilationCacheAndFillCaches(FunctionLibraryRuntime * flib_runtime,const SessionMetadata * session_metadata,const TpuMeshStateInterface * mesh_state,const std::vector<TensorShape> & dynamic_shapes,const OpInputList & guaranteed_constants,TpuPersistentCompilationCacheInterface * persistent_cache,const tpu::TpuCompilationCacheKey & key,TpuProgramGroupInterface * tpu_program_group)127 virtual Status LookupPersistentCompilationCacheAndFillCaches( 128 FunctionLibraryRuntime* flib_runtime, 129 const SessionMetadata* session_metadata, 130 const TpuMeshStateInterface* mesh_state, 131 const std::vector<TensorShape>& dynamic_shapes, 132 const OpInputList& guaranteed_constants, 133 TpuPersistentCompilationCacheInterface* persistent_cache, 134 const tpu::TpuCompilationCacheKey& key, 135 TpuProgramGroupInterface* tpu_program_group) { 136 LOG(FATAL) << "Lookup from a persistent cache is NOT supported."; 137 } 138 139 // Sleeps for `kSleepSeconds` seconds to give time for TPUCompileOp to finish 140 // before terminating peacefully. 141 static void ExitCountdown(Env* env, std::shared_ptr<std::atomic<bool>> done); 142 143 // Converts the `dynamic_shapes` arguments to the compile operator into 144 // TensorShapes. 145 static Status GetDynamicShapes(OpKernelContext* ctx, 146 std::vector<TensorShape>* shapes); 147 148 tpu::TPUCompileMetadataProto metadata_; 149 150 // Whether to compile given MLIR module in `mlir_module` instead of 151 // TensorFlow function referenced in `function_`. 152 bool use_mlir_; 153 154 // Function containing the computation to compile. 155 NameAttrList function_; 156 157 // A serialized MLIR ModuleOp. 158 std::string mlir_module_; 159 // Fingerprint of the MLIR Module created once on construction to avoid paying 160 // the cost on each invocation. 161 uint64 mlir_module_fingerprint_ = 0; 162 163 // Number of different programs to compile. This maps to number of cores in 164 // each replica. 
165 int num_computations_; 166 167 // A flag to populate HLO protos field in CompilationResultProto. The HLO 168 // metadata could be large so default to not populating it unless explicitly 169 // requested. 170 bool return_hlo_protos_; 171 172 // If enabled, DirectSession::Close will unload cache entries created during 173 // the lifetime of the session. 174 bool unload_cache_entry_on_session_close_; 175 176 // Persistent cache for compiled TPU program for inference. 177 std::unique_ptr<TpuPersistentCompilationCacheInterface> persistent_cache_; 178 179 Status RegisterXLAFingerprints(const std::vector<TensorShape>& arg_shapes, 180 TpuProgramGroupInterface* tpu_program_group, 181 uint64 fingerprint); 182 183 private: 184 TF_DISALLOW_COPY_AND_ASSIGN(TpuCompileOpKernelCommon); 185 }; 186 } // namespace tpu 187 } // namespace tensorflow 188 189 #endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_COMMON_H_ 190