1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <memory>
16
17 #include "absl/cleanup/cleanup.h"
18 #include "absl/types/span.h"
19 #include "tensorflow/compiler/xla/service/compiler.h"
20 #include "tensorflow/compiler/xla/service/executable.h"
21 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
22 #include "tensorflow/compiler/xla/service/hlo_module.h"
23 #include "tensorflow/compiler/xla/service/hlo_module_group.h"
24 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/util.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/stream_executor/device_memory_allocator.h"
29 #include "tensorflow/stream_executor/tpu/c_api_conversions.h"
30 #include "tensorflow/stream_executor/tpu/c_api_decl.h"
31 #include "tensorflow/stream_executor/tpu/proto_helper.h"
32 #include "tensorflow/stream_executor/tpu/status_helper.h"
33 #include "tensorflow/stream_executor/tpu/tpu_executable.h"
34 #include "tensorflow/stream_executor/tpu/tpu_executor.h"
35 #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
36 #include "tensorflow/stream_executor/tpu/tpu_platform.h"
37 #include "tensorflow/stream_executor/tpu/tpu_platform_id.h"
38
39 namespace xla {
40
41 namespace {
42
43 using ::tensorflow::tpu::ExecutorApiFn;
44
45 class TpuCompiler : public Compiler {
46 public:
TpuCompiler()47 TpuCompiler() { compiler_ = ExecutorApiFn()->TpuCompiler_NewFn(); }
~TpuCompiler()48 ~TpuCompiler() override { ExecutorApiFn()->TpuCompiler_FreeFn(compiler_); }
49
PlatformId() const50 stream_executor::Platform::Id PlatformId() const override {
51 return tensorflow::tpu::GetTpuPlatformId();
52 }
53
RunHloPasses(std::unique_ptr<HloModule> module,stream_executor::StreamExecutor * executor,const CompileOptions & options)54 StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
55 std::unique_ptr<HloModule> module,
56 stream_executor::StreamExecutor* executor,
57 const CompileOptions& options) override {
58 XLA_HloModule hlo_module;
59 auto cleanup = absl::MakeCleanup([&hlo_module]() {
60 stream_executor::tpu::SerializedProto_Free(hlo_module.proto);
61 ApiConverter::Destroy(&hlo_module.module_config);
62 });
63 hlo_module.module_config = ApiConverter::ToC(module->config());
64 hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto());
65 auto allocator = ApiConverter::ToC(options.device_allocator);
66 XLA_HloModule result;
67 StatusHelper status;
68 ExecutorApiFn()->TpuCompiler_RunHloPassesFn(
69 compiler_, &hlo_module,
70 static_cast<tensorflow::tpu::TpuExecutor*>(executor->implementation())
71 ->se_executor(),
72 &allocator, &result, status.c_status);
73 if (!status.ok()) {
74 return status.status();
75 }
76 HloModuleProto result_proto =
77 stream_executor::tpu::DeserializeProto<HloModuleProto>(result.proto);
78 stream_executor::tpu::SerializedProto_Free(result.proto);
79 return HloModule::CreateFromProto(result_proto, module->config());
80 }
81
RunBackend(std::unique_ptr<HloModule> module,stream_executor::StreamExecutor * executor,const CompileOptions & options)82 StatusOr<std::unique_ptr<Executable>> RunBackend(
83 std::unique_ptr<HloModule> module,
84 stream_executor::StreamExecutor* executor,
85 const CompileOptions& options) override {
86 XLA_HloModule hlo_module;
87 auto cleanup = absl::MakeCleanup([&hlo_module]() {
88 stream_executor::tpu::SerializedProto_Free(hlo_module.proto);
89 ApiConverter::Destroy(&hlo_module.module_config);
90 });
91 SE_Executable* result;
92 hlo_module.module_config = ApiConverter::ToC(module->config());
93 hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto());
94 auto allocator = ApiConverter::ToC(options.device_allocator);
95
96 StatusHelper status;
97 ExecutorApiFn()->TpuCompiler_RunBackendFn(
98 compiler_, &hlo_module,
99 static_cast<tensorflow::tpu::TpuExecutor*>(executor->implementation())
100 ->se_executor(),
101 &allocator, &result, status.c_status);
102 if (!status.ok()) {
103 return status.status();
104 }
105
106 std::unique_ptr<Executable> exec =
107 std::make_unique<TpuExecutable>(result, std::move(module));
108 return exec;
109 }
110
Compile(std::unique_ptr<HloModuleGroup> module_group,std::vector<std::vector<stream_executor::StreamExecutor * >> stream_exec,const CompileOptions & options)111 StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
112 std::unique_ptr<HloModuleGroup> module_group,
113 std::vector<std::vector<stream_executor::StreamExecutor*>> stream_exec,
114 const CompileOptions& options) override {
115 XLA_HloModuleGroup se_module_group;
116 se_module_group.proto =
117 stream_executor::tpu::SerializeProto(module_group->ToProto());
118 se_module_group.module_config =
119 new XLA_HloModuleConfig[module_group->size()];
120 int module_group_size = module_group->size();
121 auto cleanup_config =
122 absl::MakeCleanup([&se_module_group, module_group_size]() {
123 for (auto i = 0; i < module_group_size; ++i) {
124 ApiConverter::Destroy(&se_module_group.module_config[i]);
125 }
126 delete[] se_module_group.module_config;
127 });
128 for (int i = 0; i < module_group->size(); ++i) {
129 const auto& config = module_group->module(i).config();
130 se_module_group.module_config[i] = ApiConverter::ToC(config);
131 }
132 std::vector<SE_StreamExecutorList> se_lists(stream_exec.size());
133 std::vector<std::vector<SE_StreamExecutor*>> se_lists_storage;
134 for (int i = 0; i < stream_exec.size(); ++i) {
135 se_lists[i].count = stream_exec[i].size();
136 se_lists_storage.emplace_back(stream_exec[i].size());
137 se_lists[i].exec = se_lists_storage.back().data();
138 for (int j = 0; j < stream_exec[i].size(); ++j) {
139 se_lists[i].exec[j] = static_cast<tensorflow::tpu::TpuExecutor*>(
140 stream_exec[i][j]->implementation())
141 ->se_executor();
142 }
143 }
144
145 SE_DeviceMemoryAllocator allocator =
146 ApiConverter::ToC(options.device_allocator);
147
148 SE_Executable** se_executables = new SE_Executable*[module_group->size()];
149
150 StatusHelper status;
151
152 ExecutorApiFn()->TpuCompiler_CompileFn(
153 compiler_, &se_module_group, se_lists.data(), stream_exec.size(),
154 &allocator, se_executables, status.c_status);
155
156 if (!status.ok()) {
157 return status.status();
158 }
159
160 std::vector<std::unique_ptr<Executable>> executables;
161 for (int i = 0; i < module_group->size(); ++i) {
162 // We get the HloModule from the compiled executable, rather than reusing
163 // the input module from 'module_group', in case the module changed in
164 // some way. For example, if the computation is automatically partitioned
165 // via XLA, the executable's module may have different input/output shapes
166 // than the input module.
167 XLA_HloModule c_module =
168 ExecutorApiFn()->TpuExecutable_HloModuleFn(se_executables[i]);
169 auto cleanup_c_module = absl::MakeCleanup(
170 [&c_module]() { ApiConverter::Destroy(&c_module); });
171 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
172 ApiConverter::FromC(c_module));
173 std::shared_ptr<HloModule> module_shared(module.release());
174 executables.emplace_back(std::make_unique<TpuExecutable>(
175 se_executables[i], std::move(module_shared)));
176 }
177
178 stream_executor::tpu::SerializedProto_Free(se_module_group.proto);
179 delete[] se_executables;
180
181 return executables;
182 }
183
184 // Compiles the HLO module group for ahead-of-time execution. This is
185 // intended for use in static compilation.
186 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,const AotCompilationOptions & options)187 CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
188 const AotCompilationOptions& options) override {
189 return Unimplemented("This compiler does not support CompileAheadOfTime.");
190 }
191
192 // Returns a function that computes the size in bytes of the logical
193 // buffer that contains a shape.
ShapeSizeBytesFunction() const194 HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {
195 return [this](const xla::Shape& shape) {
196 XLA_Shape c_shape;
197 ApiConverter::ToC(shape, &c_shape);
198 int64_t bytes =
199 ExecutorApiFn()->TpuCompiler_ShapeSizeFn(compiler_, &c_shape);
200 ApiConverter::Destroy(&c_shape);
201 return bytes;
202 };
203 }
204
DefaultDeviceShapeRepresentation(const Shape & shape) const205 Shape DefaultDeviceShapeRepresentation(const Shape& shape) const override {
206 XLA_Shape host_shape, device_shape;
207 ApiConverter::ToC(shape, &host_shape);
208 ExecutorApiFn()->TpuCompiler_DefaultDeviceShapeRepresentationFn(
209 compiler_, &host_shape, &device_shape);
210 return ApiConverter::FromC(&device_shape);
211 }
212
213 private:
214 Tpu_Compiler* compiler_;
215 };
216
InitModule()217 static bool InitModule() {
218 xla::Compiler::RegisterCompilerFactory(
219 tensorflow::tpu::GetTpuPlatformId(),
220 []() { return std::make_unique<TpuCompiler>(); });
221 return true;
222 }
223
224 static bool module_initialized = InitModule();
225
226 } // namespace
227 } // namespace xla
228