1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_GROUP_H_ 16 #define TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_GROUP_H_ 17 18 #include <memory> 19 #include <vector> 20 21 #include "absl/types/optional.h" 22 #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" 23 #include "tensorflow/compiler/tf2xla/xla_compiler.h" 24 #include "tensorflow/compiler/xla/client/compile_only_client.h" 25 #include "tensorflow/compiler/xla/service/computation_placer.h" 26 #include "tensorflow/compiler/xla/service/hlo.pb.h" 27 #include "tensorflow/compiler/xrt/xrt.pb.h" 28 #include "tensorflow/core/platform/macros.h" 29 #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" 30 #include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h" 31 #include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h" 32 #include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h" 33 #include "tensorflow/core/tpu/tpu_ops_c_api.h" 34 #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h" 35 36 namespace tensorflow { 37 namespace tpu { 38 39 class TpuAotCompilationOptions : public xla::AotCompilationOptions { 40 public: TpuAotCompilationOptions(int64_t replica_count)41 explicit TpuAotCompilationOptions(int64_t replica_count) 42 : num_cores_(0), replica_count_(replica_count) {} 43 44 // Returns the ID of the platform to which these options apply. PlatformId()45 se::Platform::Id PlatformId() const override { 46 LOG(FATAL) << "Not implemented."; 47 return nullptr; 48 }; 49 set_num_cores(int64_t tpu_cores)50 void set_num_cores(int64_t tpu_cores) { num_cores_ = tpu_cores; } replica_count()51 int64 replica_count() const override { return replica_count_; } num_cores()52 int64 num_cores() const override { return num_cores_; } 53 set_allow_separate_sharding_programs(bool allow)54 void set_allow_separate_sharding_programs(bool allow) { 55 allow_separate_sharding_programs_ = allow; 56 } allow_separate_sharding_programs()57 bool allow_separate_sharding_programs() const { 58 return allow_separate_sharding_programs_; 59 } 60 61 const std::vector<xla::HloModuleConfig::ShardableValueUpdatePair> shardable_value_update_pairs()62 shardable_value_update_pairs() const { 63 return shardable_value_update_pairs_; 64 } set_shardable_value_update_pairs(std::vector<xla::HloModuleConfig::ShardableValueUpdatePair> pairs)65 void set_shardable_value_update_pairs( 66 std::vector<xla::HloModuleConfig::ShardableValueUpdatePair> pairs) { 67 shardable_value_update_pairs_ = std::move(pairs); 68 } 69 70 private: 71 int64 num_cores_; 72 int64 replica_count_; 73 74 // Whether to allow the compiler to create separte sharding and unsharding 75 // programs, and modify the original program's input/output sharded size. This 76 // is used for XLA-chosen sharding on parameters without an on-device loop: 77 // the caller can invoke sharding first, then (repeatedly) invoke the sharded 78 // main program, and finally invoke the unsharding program when it needs the 79 // full output. 80 bool allow_separate_sharding_programs_ = false; 81 82 // The list of input/output pairs in the main program that could be sharded. 83 std::vector<xla::HloModuleConfig::ShardableValueUpdatePair> 84 shardable_value_update_pairs_; 85 }; 86 87 class TpuProgramGroup : public TpuProgramGroupInterface { 88 public: 89 using Status = ::stream_executor::port::Status; 90 91 // Compiles Mlir or TF function computation by lowering into HLO IR and 92 // returns TPU programs ready for execution. 93 static Status CompileAndBuild( 94 const TpuCompilationRequestProto& compilation_request, 95 const XLA_TpuMeshState* mesh_state, 96 TpuProgramGroupInterface* tpu_program_group_interface); 97 98 // Compiles HLO IR and returns TPU programs ready for execution. 99 static Status CompileAndBuild( 100 const xrt::XLAComputation& xrt_computation_proto, 101 const XLA_TpuMeshState* mesh_state, 102 TpuProgramGroupInterface* tpu_program_group_interface); 103 104 // Initializes `TpuProgramGroup` object with `xla_tpu_programs`. 105 void Initialize(absl::Span<XLA_TpuProgram* const> xla_tpu_programs); 106 107 TpuProgramGroup() = default; 108 TpuProgramGroup(TpuProgramGroup&& other); 109 TpuProgramGroup& operator=(TpuProgramGroup&&) = delete; 110 111 bool has_sharding_program() const override; 112 113 size_t program_count() const override; 114 115 int64_t program_size() const override; 116 117 bool LogProgramMemorySummary() override; 118 119 void UnloadAndDestroyPrograms() override; 120 121 Status LogCompilationStats(const TpuCompilationCacheKey& key, 122 absl::Duration duration) override; 123 124 const std::vector<bool>& may_modify_variables_list() const override; 125 void set_may_modify_variables(const std::vector<bool>& may_modify_variables); 126 bool may_modify_variables(int index) const override; 127 128 const std::vector<std::string>& fingerprints() const; 129 void set_fingerprints(); 130 131 const std::string& fingerprint(int index) const override; 132 133 const std::vector<XLA_TpuProgram*>& tpu_programs() const; 134 std::vector<XLA_TpuProgram*> tpu_programs(TpuProgramShardingType type) const; 135 const XLA_TpuProgram* tpu_program(int index) const override; 136 void set_tpu_programs(absl::Span<XLA_TpuProgram* const> tpu_programs); 137 138 const TPUExecutableInfoProto& executable_info(int index) const override; 139 140 const TPUHostTransferInfoProto& host_transfer_info(int index) const override; 141 void set_hlo_metadatas(absl::Span<const xla::HloProto> hlo_metadatas); 142 const xla::HloProto* hlo_metadata(int index) const; 143 absl::Span<const xla::HloProto* const> hlo_metadatas() const override; 144 145 // Deserializes `GetTpuProgramResponse` protos from remote cache. 146 Status DeserializeFromRpcResponseProtos( 147 const std::vector<TpuSerializedProto>& rpc_response_protos); 148 149 // Serializes executable proto from the TPU program for the given core 150 // `index`. 151 Status SerializeExecutable(int index, 152 TpuExecutableSerializedProto* executable) const; 153 154 // Serializes compiler metadata of the TPU program for the given core `index`. 155 Status SerializeCompilerMetadata( 156 int index, CompilerMetadataSerializedProto* compiler_metadata) const; 157 158 // Serializes host compute metadata of the TPU program for the given core 159 // `index`. 160 Status SerializeHostComputeMetadata( 161 int index, 162 HostComputeMetadataSerializedProto* host_compute_metadata) const; 163 164 private: 165 TPUExecutableInfoProto ConstructExecutableInfo( 166 const XLA_TpuProgram* tpu_program); 167 TPUHostTransferInfoProto ConstructHostTransferInfo( 168 const XLA_TpuProgram* tpu_program); 169 xla::HloProto ConstructHloMetadata(const XLA_TpuProgram* tpu_program); 170 171 // Update `hlo_metadatas__ptrs_` array from `hlo_metadatas_`. This needs to be 172 // called on `hlo_metadatas_` change(s). 173 void RefreshHloMetadatasPtrs(); 174 175 std::vector<bool> may_modify_variables_; 176 std::vector<std::string> tpu_program_fingerprints_; 177 178 std::vector<XLA_TpuProgram*> tpu_programs_; // Not owned. 179 std::vector<TPUExecutableInfoProto> executable_infos_; 180 std::vector<TPUHostTransferInfoProto> host_transfer_infos_; 181 182 // To be consistent with the TpuProgramGroupInterface::hlo_metadatas() 183 // signature, we store HloProto values in hlo_metadatas_ when 184 // set_hlo_metadata(...) is called, and return their pointers from 185 // hlo_metadatas_ptrs_ when hlo_metadatas() is called. hlo_metadata_ptrs_ is 186 // refreshed whenever hlo_metadatas_ is set or the object is moved. 187 std::vector<xla::HloProto> hlo_metadatas_; // Owned. 188 std::vector<const xla::HloProto*> hlo_metadatas_ptrs_; 189 190 TF_DISALLOW_COPY_AND_ASSIGN(TpuProgramGroup); 191 }; 192 193 } // namespace tpu 194 } // namespace tensorflow 195 196 #endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_GROUP_H_ 197