/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_COMMON_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_COMMON_H_

#include <atomic>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/jit/shape_inference.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

namespace tensorflow {
namespace tpu {
// Forward declaration, defined below.
42 class TpuCompileOpKernelCommon; 43 44 // A base factory class for creating a `TpuCompileOpKernelImpl` variant. 45 // By design, the actual factory can only be set once. 46 class CompileOpImplFactory { 47 public: 48 virtual ~CompileOpImplFactory() = default; 49 50 virtual stream_executor::port::StatusOr< 51 std::unique_ptr<TpuCompileOpKernelCommon>> 52 CreateNonMlirImpl(OpKernelConstruction* ctx) = 0; 53 54 virtual stream_executor::port::StatusOr< 55 std::unique_ptr<TpuCompileOpKernelCommon>> 56 CreateMlirImpl(OpKernelConstruction* ctx) = 0; 57 58 static CompileOpImplFactory* Get(); 59 static void Register(CompileOpImplFactory* factory); 60 61 private: 62 static CompileOpImplFactory* factory_; 63 }; 64 65 // Abstract base class for TpuCompileOpKernel implementation. 66 class TpuCompileOpKernelCommon { 67 public: TpuCompileOpKernelCommon(const std::string & mlir_module,const tpu::TPUCompileMetadataProto metadata,int num_computations,bool return_hlo_protos,bool unload_cache_on_session_close)68 TpuCompileOpKernelCommon(const std::string& mlir_module, 69 const tpu::TPUCompileMetadataProto metadata, 70 int num_computations, bool return_hlo_protos, 71 bool unload_cache_on_session_close) 72 : metadata_(metadata), 73 use_mlir_(true), 74 mlir_module_(mlir_module), 75 num_computations_(num_computations), 76 return_hlo_protos_(return_hlo_protos), 77 unload_cache_entry_on_session_close_(unload_cache_on_session_close), 78 persistent_cache_(nullptr) { 79 mlir_module_fingerprint_ = tensorflow::Fingerprint64(mlir_module_); 80 } 81 TpuCompileOpKernelCommon(const NameAttrList & function,const tpu::TPUCompileMetadataProto metadata,int num_computations,bool return_hlo_protos,bool unload_cache_on_session_close,std::unique_ptr<TpuPersistentCompilationCacheInterface> persistent_cache)82 TpuCompileOpKernelCommon( 83 const NameAttrList& function, const tpu::TPUCompileMetadataProto metadata, 84 int num_computations, bool return_hlo_protos, 85 bool unload_cache_on_session_close, 86 
std::unique_ptr<TpuPersistentCompilationCacheInterface> persistent_cache) 87 : metadata_(metadata), 88 use_mlir_(false), 89 function_(function), 90 num_computations_(num_computations), 91 return_hlo_protos_(return_hlo_protos), 92 unload_cache_entry_on_session_close_(unload_cache_on_session_close), 93 persistent_cache_(std::move(persistent_cache)) {} 94 95 virtual ~TpuCompileOpKernelCommon() = default; 96 97 void Compute(OpKernelContext* ctx); 98 99 // Lowers Mlir or TF Function computation into HLO IR and using XLA compiler 100 // compiles into TPU programs ready for execution. 101 virtual Status Compile( 102 const absl::variant<MlirToHloArgs, FunctionToHloArgs>& computation, 103 const XLA_TpuMeshState* mesh_state, 104 const std::vector<TensorShape>& arg_shapes, 105 TpuProgramGroupInterface* tpu_program_group) = 0; 106 107 // Performs shape inference on `computation`, filling shape_info with operator 108 // shapes. The shapes of the _Arg nodes are taken from `arg_shapes`. 109 static Status RunShapeInferenceOnComputation( 110 const tpu::TPUCompileMetadataProto& metadata, 111 const std::vector<PartialTensorShape>& arg_shapes, Graph* graph, 112 FunctionLibraryRuntime* flr, GraphShapeInfo* shape_info); 113 114 protected: 115 Status ComputeInternal(OpKernelContext* ctx); 116 117 // Compile TPU program locally and populate the host compilation cache. 118 Status CompileLocallyAndFillHostCache( 119 FunctionLibraryRuntime* flib_runtime, 120 const SessionMetadata* session_metadata, 121 const TpuMeshStateInterface* mesh_state, 122 const std::vector<TensorShape>& dynamic_shapes, 123 const OpInputList& guaranteed_constants, 124 const tpu::TpuCompilationCacheKey& key, 125 TpuProgramGroupInterface* tpu_program_group); 126 127 // Lookup from persistent compilation cache and populate both host cache and 128 // persistent cache. 
LookupPersistentCompilationCacheAndFillCaches(FunctionLibraryRuntime * flib_runtime,const SessionMetadata * session_metadata,const TpuMeshStateInterface * mesh_state,const std::vector<TensorShape> & dynamic_shapes,const OpInputList & guaranteed_constants,TpuPersistentCompilationCacheInterface * persistent_cache,const tpu::TpuCompilationCacheKey & key,TpuProgramGroupInterface * tpu_program_group)129 virtual Status LookupPersistentCompilationCacheAndFillCaches( 130 FunctionLibraryRuntime* flib_runtime, 131 const SessionMetadata* session_metadata, 132 const TpuMeshStateInterface* mesh_state, 133 const std::vector<TensorShape>& dynamic_shapes, 134 const OpInputList& guaranteed_constants, 135 TpuPersistentCompilationCacheInterface* persistent_cache, 136 const tpu::TpuCompilationCacheKey& key, 137 TpuProgramGroupInterface* tpu_program_group) { 138 LOG(FATAL) << "Lookup from a persistent cache is NOT supported."; 139 } 140 141 // Sleeps for `kSleepSeconds` seconds to give time for TPUCompileOp to finish 142 // before terminating peacefully. 143 static void ExitCountdown(Env* env, std::shared_ptr<std::atomic<bool>> done); 144 145 // Converts the `dynamic_shapes` arguments to the compile operator into 146 // TensorShapes. 147 static Status GetDynamicShapes(OpKernelContext* ctx, 148 std::vector<TensorShape>* shapes); 149 150 // Adds TPU_REPLICATED_CORE device assignments to the _Arg and _Retval 151 // nodes in `graph', using the sharding/index assignments in 152 // `arg_core_mapping` and `retval_core_mapping`. The mappings are maps from 153 // original argument/return index to (sharding, per-core argument/return 154 // index) pairs. Node attributes, such as device assignments, are not 155 // preserved on function argument and return values nodes, so we must recreate 156 // them the compilation metadata. 
157 static Status AssignDevicesToArgsAndRetvals( 158 absl::Span<const tpu::ShardingAndIndex> arg_core_mapping, 159 absl::Span<const tpu::ShardingAndIndex> retval_core_mapping, 160 Graph* graph); 161 162 // Optimizes `graph`, given the argument descriptions in `metadata` and 163 // `arg_shapes`. 164 static Status OptimizeGraph(const tpu::TPUCompileMetadataProto& metadata, 165 const std::vector<PartialTensorShape>& arg_shapes, 166 std::unique_ptr<Graph>* graph, 167 FunctionLibraryRuntime* flr, 168 FunctionLibraryDefinition* fld); 169 170 // Converts a TF Function into XLA HLO, stores generated HLO module and 171 // accompanying metadata in CompilationResult. 172 Status CompileTFFunctionToHlo( 173 const FunctionLibraryDefinition& flib_def, int graph_def_version, 174 const XlaCompiler::ShapeRepresentationFn shape_representation_fn, 175 const std::vector<TensorShape>& arg_shapes, 176 const GuaranteedConsts& guaranteed_constants, 177 const NameAttrList& function, 178 std::function<Status(ResourceMgr*)> populate_resource_manager_fn, 179 xla::CompileOnlyClient* client, 180 std::vector<tpu::ShardingAndIndex>* arg_core_mapping, 181 std::vector<std::vector<xla::Shape>>* per_core_arg_shapes, 182 XlaCompiler::CompilationResult* compilation_result); 183 184 // Gets information regarding how input arguments are sharded across multiple 185 // cores. 186 Status GetShardingInfo( 187 absl::Span<const TensorShape> arg_shapes, 188 const XlaCompiler::ShapeRepresentationFn shape_representation_fn, 189 std::vector<tpu::ShardingAndIndex>* arg_core_mapping, 190 std::vector<std::vector<xla::Shape>>* per_core_arg_shapes); 191 192 // Populates the mapping from return value to ShardingAndIndex. 193 Status AssignReturnValueToCore( 194 std::vector<tpu::ShardingAndIndex>* retval_core_mapping); 195 196 // Populates the arguments, core mapping and per core argument shape for the 197 // computation. 
198 Status BuildComputationArgumentDescriptions( 199 const std::vector<TensorShape>& arg_shapes, 200 const GuaranteedConsts& guaranteed_constants, const XlaCompiler& compiler, 201 std::vector<XlaCompiler::Argument>* args, 202 std::vector<tpu::ShardingAndIndex>* arg_core_mapping, 203 std::vector<std::vector<xla::Shape>>* per_core_arg_shapes); 204 205 const tpu::TPUCompileMetadataProto metadata_; 206 207 // Whether to compile given MLIR module in `mlir_module` instead of 208 // TensorFlow function referenced in `function_`. 209 bool use_mlir_; 210 211 // Function containing the computation to compile. 212 NameAttrList function_; 213 214 // A serialized MLIR ModuleOp. 215 std::string mlir_module_; 216 // Fingerprint of the MLIR Module created once on construction to avoid paying 217 // the cost on each invocation. 218 uint64 mlir_module_fingerprint_ = 0; 219 220 // Number of different programs to compile. This maps to number of cores in 221 // each replica. 222 int num_computations_; 223 224 // A flag to populate HLO protos field in CompilationResultProto. The HLO 225 // metadata could be large so default to not populating it unless explicitly 226 // requested. 227 bool return_hlo_protos_; 228 229 // If enabled, DirectSession::Close will unload cache entries created during 230 // the lifetime of the session. 231 bool unload_cache_entry_on_session_close_; 232 233 // Persistent cache for compiled TPU program for inference. 234 std::unique_ptr<TpuPersistentCompilationCacheInterface> persistent_cache_; 235 236 private: 237 TF_DISALLOW_COPY_AND_ASSIGN(TpuCompileOpKernelCommon); 238 }; 239 } // namespace tpu 240 } // namespace tensorflow 241 242 #endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_COMMON_H_ 243