/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_GROUP_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_GROUP_H_

#include <memory>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

namespace tensorflow {
namespace tpu {

class TpuAotCompilationOptions : public xla::AotCompilationOptions {
 public:
  explicit TpuAotCompilationOptions(int64_t replica_count)
      : num_cores_(0), replica_count_(replica_count) {}

  // Returns the ID of the platform to which these options apply.
  se::Platform::Id PlatformId() const override {
    LOG(FATAL) << "Not implemented.";
    return nullptr;
  }

  void set_num_cores(int64_t tpu_cores) { num_cores_ = tpu_cores; }
  int64 replica_count() const override { return replica_count_; }
  int64 num_cores() const override { return num_cores_; }

  void set_allow_separate_sharding_programs(bool allow) {
    allow_separate_sharding_programs_ = allow;
  }
  bool allow_separate_sharding_programs() const {
    return allow_separate_sharding_programs_;
  }

  const std::vector<xla::HloModuleConfig::ShardableValueUpdatePair>
  shardable_value_update_pairs() const {
    return shardable_value_update_pairs_;
  }
  void set_shardable_value_update_pairs(
      std::vector<xla::HloModuleConfig::ShardableValueUpdatePair> pairs) {
    shardable_value_update_pairs_ = std::move(pairs);
  }

 private:
  int64 num_cores_;
  int64 replica_count_;

  // Whether to allow the compiler to create separate sharding and unsharding
  // programs, and to modify the original program's sharded input/output sizes.
  // This is used for XLA-chosen sharding on parameters without an on-device
  // loop: the caller can invoke sharding first, then (repeatedly) invoke the
  // sharded main program, and finally invoke the unsharding program when it
  // needs the full output.
  bool allow_separate_sharding_programs_ = false;

  // The list of input/output pairs in the main program that could be sharded.
  std::vector<xla::HloModuleConfig::ShardableValueUpdatePair>
      shardable_value_update_pairs_;
};
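
// Illustrative usage of the options above (a sketch; the replica and core
// counts are arbitrary assumptions, not values from this header). When
// separate sharding programs are allowed, the intended call order is: the
// sharding program first, then the sharded main program (possibly
// repeatedly), and finally the unsharding program to recover the full output.
//
//   TpuAotCompilationOptions options(/*replica_count=*/2);
//   options.set_num_cores(2);
//   options.set_allow_separate_sharding_programs(true);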

class TpuProgramGroup : public TpuProgramGroupInterface {
 public:
  using Status = ::stream_executor::port::Status;

  // Compiles an MLIR or TF function computation by lowering it into HLO IR,
  // and returns TPU programs ready for execution.
  static Status CompileAndBuild(
      const TpuCompilationRequestProto& compilation_request,
      const XLA_TpuMeshState* mesh_state,
      TpuProgramGroupInterface* tpu_program_group_interface);

  // Compiles HLO IR and returns TPU programs ready for execution.
  static Status CompileAndBuild(
      const xrt::XLAComputation& xrt_computation_proto,
      const XLA_TpuMeshState* mesh_state,
      TpuProgramGroupInterface* tpu_program_group_interface);
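
  // A minimal usage sketch for either CompileAndBuild overload (illustrative
  // only; building the request proto and obtaining a live `XLA_TpuMeshState`
  // are outside the scope of this header):
  //
  //   TpuProgramGroup tpu_program_group;
  //   TF_RETURN_IF_ERROR(TpuProgramGroup::CompileAndBuild(
  //       compilation_request, mesh_state, &tpu_program_group));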

  // Initializes the `TpuProgramGroup` object with `xla_tpu_programs`.
  void Initialize(absl::Span<XLA_TpuProgram* const> xla_tpu_programs);

  TpuProgramGroup() = default;
  TpuProgramGroup(TpuProgramGroup&& other);
  TpuProgramGroup& operator=(TpuProgramGroup&&) = delete;

  bool has_sharding_program() const override;

  size_t program_count() const override;

  int64_t program_size() const override;

  bool LogProgramMemorySummary() override;

  void UnloadAndDestroyPrograms() override;

  Status LogCompilationStats(const TpuCompilationCacheKey& key,
                             absl::Duration duration) override;

  const std::vector<bool>& may_modify_variables_list() const override;
  void set_may_modify_variables(const std::vector<bool>& may_modify_variables);
  bool may_modify_variables(int index) const override;

  const std::vector<std::string>& fingerprints() const;
  void set_fingerprints();

  const std::string& fingerprint(int index) const override;

  const std::vector<XLA_TpuProgram*>& tpu_programs() const;
  std::vector<XLA_TpuProgram*> tpu_programs(TpuProgramShardingType type) const;
  const XLA_TpuProgram* tpu_program(int index) const override;
  void set_tpu_programs(absl::Span<XLA_TpuProgram* const> tpu_programs);

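  // Illustrative per-core read of the program list (a sketch; assumes a
  // `TpuProgramGroup` named `group` already populated via CompileAndBuild
  // or Initialize):
  //
  //   for (size_t i = 0; i < group.program_count(); ++i) {
  //     const XLA_TpuProgram* per_core_program = group.tpu_program(i);
  //   }
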
  const TPUExecutableInfoProto& executable_info(int index) const override;

  const TPUHostTransferInfoProto& host_transfer_info(int index) const override;
  void set_hlo_metadatas(absl::Span<const xla::HloProto> hlo_metadatas);
  const xla::HloProto* hlo_metadata(int index) const;
  absl::Span<const xla::HloProto* const> hlo_metadatas() const override;

  // Deserializes `GetTpuProgramResponse` protos from the remote cache.
  Status DeserializeFromRpcResponseProtos(
      const std::vector<TpuSerializedProto>& rpc_response_protos);

  // Serializes the executable proto from the TPU program for the given core
  // `index`.
  Status SerializeExecutable(int index,
                             TpuExecutableSerializedProto* executable) const;

  // Serializes the compiler metadata of the TPU program for the given core
  // `index`.
  Status SerializeCompilerMetadata(
      int index, CompilerMetadataSerializedProto* compiler_metadata) const;

  // Serializes the host compute metadata of the TPU program for the given
  // core `index`.
  Status SerializeHostComputeMetadata(
      int index,
      HostComputeMetadataSerializedProto* host_compute_metadata) const;
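
  // A usage sketch for the serialization helpers above (illustrative only;
  // `group` and the consumer of the serialized bytes are assumptions):
  //
  //   TpuExecutableSerializedProto executable;
  //   TF_RETURN_IF_ERROR(group.SerializeExecutable(/*index=*/0, &executable));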

 private:
  TPUExecutableInfoProto ConstructExecutableInfo(
      const XLA_TpuProgram* tpu_program);
  TPUHostTransferInfoProto ConstructHostTransferInfo(
      const XLA_TpuProgram* tpu_program);
  xla::HloProto ConstructHloMetadata(const XLA_TpuProgram* tpu_program);

  // Updates the `hlo_metadatas_ptrs_` array from `hlo_metadatas_`. This must
  // be called whenever `hlo_metadatas_` changes.
  void RefreshHloMetadatasPtrs();
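
  // A minimal sketch of the refresh pattern described above (an illustrative
  // assumption; the actual definition lives in the corresponding .cc file):
  //
  //   void TpuProgramGroup::RefreshHloMetadatasPtrs() {
  //     hlo_metadatas_ptrs_.clear();
  //     hlo_metadatas_ptrs_.reserve(hlo_metadatas_.size());
  //     for (const xla::HloProto& hlo_metadata : hlo_metadatas_) {
  //       hlo_metadatas_ptrs_.push_back(&hlo_metadata);
  //     }
  //   }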

  std::vector<bool> may_modify_variables_;
  std::vector<std::string> tpu_program_fingerprints_;

  std::vector<XLA_TpuProgram*> tpu_programs_;  // Not owned.
  std::vector<TPUExecutableInfoProto> executable_infos_;
  std::vector<TPUHostTransferInfoProto> host_transfer_infos_;

  // To be consistent with the TpuProgramGroupInterface::hlo_metadatas()
  // signature, we store HloProto values in `hlo_metadatas_` when
  // set_hlo_metadatas(...) is called, and return their pointers from
  // `hlo_metadatas_ptrs_` when hlo_metadatas() is called. `hlo_metadatas_ptrs_`
  // is refreshed whenever `hlo_metadatas_` is set or the object is moved.
  std::vector<xla::HloProto> hlo_metadatas_;  // Owned.
  std::vector<const xla::HloProto*> hlo_metadatas_ptrs_;

  TF_DISALLOW_COPY_AND_ASSIGN(TpuProgramGroup);
};

}  // namespace tpu
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_GROUP_H_