• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <memory>
16 
17 #include "absl/cleanup/cleanup.h"
18 #include "absl/types/span.h"
19 #include "tensorflow/compiler/xla/service/compiler.h"
20 #include "tensorflow/compiler/xla/service/executable.h"
21 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
22 #include "tensorflow/compiler/xla/service/hlo_module.h"
23 #include "tensorflow/compiler/xla/service/hlo_module_group.h"
24 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/util.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/stream_executor/device_memory_allocator.h"
29 #include "tensorflow/stream_executor/tpu/c_api_conversions.h"
30 #include "tensorflow/stream_executor/tpu/c_api_decl.h"
31 #include "tensorflow/stream_executor/tpu/proto_helper.h"
32 #include "tensorflow/stream_executor/tpu/status_helper.h"
33 #include "tensorflow/stream_executor/tpu/tpu_executable.h"
34 #include "tensorflow/stream_executor/tpu/tpu_executor.h"
35 #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
36 #include "tensorflow/stream_executor/tpu/tpu_platform.h"
37 #include "tensorflow/stream_executor/tpu/tpu_platform_id.h"
38 
39 namespace xla {
40 
41 namespace {
42 
43 using ::tensorflow::tpu::ExecutorApiFn;
44 
45 class TpuCompiler : public Compiler {
46  public:
TpuCompiler()47   TpuCompiler() { compiler_ = ExecutorApiFn()->TpuCompiler_NewFn(); }
~TpuCompiler()48   ~TpuCompiler() override { ExecutorApiFn()->TpuCompiler_FreeFn(compiler_); }
49 
PlatformId() const50   stream_executor::Platform::Id PlatformId() const override {
51     return tensorflow::tpu::GetTpuPlatformId();
52   }
53 
RunHloPasses(std::unique_ptr<HloModule> module,stream_executor::StreamExecutor * executor,const CompileOptions & options)54   StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
55       std::unique_ptr<HloModule> module,
56       stream_executor::StreamExecutor* executor,
57       const CompileOptions& options) override {
58     XLA_HloModule hlo_module;
59     auto cleanup = absl::MakeCleanup([&hlo_module]() {
60       stream_executor::tpu::SerializedProto_Free(hlo_module.proto);
61       ApiConverter::Destroy(&hlo_module.module_config);
62     });
63     hlo_module.module_config = ApiConverter::ToC(module->config());
64     hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto());
65     auto allocator = ApiConverter::ToC(options.device_allocator);
66     XLA_HloModule result;
67     StatusHelper status;
68     ExecutorApiFn()->TpuCompiler_RunHloPassesFn(
69         compiler_, &hlo_module,
70         static_cast<tensorflow::tpu::TpuExecutor*>(executor->implementation())
71             ->se_executor(),
72         &allocator, &result, status.c_status);
73     if (!status.ok()) {
74       return status.status();
75     }
76     HloModuleProto result_proto =
77         stream_executor::tpu::DeserializeProto<HloModuleProto>(result.proto);
78     stream_executor::tpu::SerializedProto_Free(result.proto);
79     return HloModule::CreateFromProto(result_proto, module->config());
80   }
81 
RunBackend(std::unique_ptr<HloModule> module,stream_executor::StreamExecutor * executor,const CompileOptions & options)82   StatusOr<std::unique_ptr<Executable>> RunBackend(
83       std::unique_ptr<HloModule> module,
84       stream_executor::StreamExecutor* executor,
85       const CompileOptions& options) override {
86     XLA_HloModule hlo_module;
87     auto cleanup = absl::MakeCleanup([&hlo_module]() {
88       stream_executor::tpu::SerializedProto_Free(hlo_module.proto);
89       ApiConverter::Destroy(&hlo_module.module_config);
90     });
91     SE_Executable* result;
92     hlo_module.module_config = ApiConverter::ToC(module->config());
93     hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto());
94     auto allocator = ApiConverter::ToC(options.device_allocator);
95 
96     StatusHelper status;
97     ExecutorApiFn()->TpuCompiler_RunBackendFn(
98         compiler_, &hlo_module,
99         static_cast<tensorflow::tpu::TpuExecutor*>(executor->implementation())
100             ->se_executor(),
101         &allocator, &result, status.c_status);
102     if (!status.ok()) {
103       return status.status();
104     }
105 
106     std::unique_ptr<Executable> exec =
107         std::make_unique<TpuExecutable>(result, std::move(module));
108     return exec;
109   }
110 
Compile(std::unique_ptr<HloModuleGroup> module_group,std::vector<std::vector<stream_executor::StreamExecutor * >> stream_exec,const CompileOptions & options)111   StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
112       std::unique_ptr<HloModuleGroup> module_group,
113       std::vector<std::vector<stream_executor::StreamExecutor*>> stream_exec,
114       const CompileOptions& options) override {
115     XLA_HloModuleGroup se_module_group;
116     se_module_group.proto =
117         stream_executor::tpu::SerializeProto(module_group->ToProto());
118     se_module_group.module_config =
119         new XLA_HloModuleConfig[module_group->size()];
120     int module_group_size = module_group->size();
121     auto cleanup_config =
122         absl::MakeCleanup([&se_module_group, module_group_size]() {
123           for (auto i = 0; i < module_group_size; ++i) {
124             ApiConverter::Destroy(&se_module_group.module_config[i]);
125           }
126           delete[] se_module_group.module_config;
127         });
128     for (int i = 0; i < module_group->size(); ++i) {
129       const auto& config = module_group->module(i).config();
130       se_module_group.module_config[i] = ApiConverter::ToC(config);
131     }
132     std::vector<SE_StreamExecutorList> se_lists(stream_exec.size());
133     std::vector<std::vector<SE_StreamExecutor*>> se_lists_storage;
134     for (int i = 0; i < stream_exec.size(); ++i) {
135       se_lists[i].count = stream_exec[i].size();
136       se_lists_storage.emplace_back(stream_exec[i].size());
137       se_lists[i].exec = se_lists_storage.back().data();
138       for (int j = 0; j < stream_exec[i].size(); ++j) {
139         se_lists[i].exec[j] = static_cast<tensorflow::tpu::TpuExecutor*>(
140                                   stream_exec[i][j]->implementation())
141                                   ->se_executor();
142       }
143     }
144 
145     SE_DeviceMemoryAllocator allocator =
146         ApiConverter::ToC(options.device_allocator);
147 
148     SE_Executable** se_executables = new SE_Executable*[module_group->size()];
149 
150     StatusHelper status;
151 
152     ExecutorApiFn()->TpuCompiler_CompileFn(
153         compiler_, &se_module_group, se_lists.data(), stream_exec.size(),
154         &allocator, se_executables, status.c_status);
155 
156     if (!status.ok()) {
157       return status.status();
158     }
159 
160     std::vector<std::unique_ptr<Executable>> executables;
161     for (int i = 0; i < module_group->size(); ++i) {
162       // We get the HloModule from the compiled executable, rather than reusing
163       // the input module from 'module_group', in case the module changed in
164       // some way. For example, if the computation is automatically partitioned
165       // via XLA, the executable's module may have different input/output shapes
166       // than the input module.
167       XLA_HloModule c_module =
168           ExecutorApiFn()->TpuExecutable_HloModuleFn(se_executables[i]);
169       auto cleanup_c_module = absl::MakeCleanup(
170           [&c_module]() { ApiConverter::Destroy(&c_module); });
171       TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
172                           ApiConverter::FromC(c_module));
173       std::shared_ptr<HloModule> module_shared(module.release());
174       executables.emplace_back(std::make_unique<TpuExecutable>(
175           se_executables[i], std::move(module_shared)));
176     }
177 
178     stream_executor::tpu::SerializedProto_Free(se_module_group.proto);
179     delete[] se_executables;
180 
181     return executables;
182   }
183 
184   // Compiles the HLO module group for ahead-of-time execution.  This is
185   // intended for use in static compilation.
186   StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,const AotCompilationOptions & options)187   CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
188                      const AotCompilationOptions& options) override {
189     return Unimplemented("This compiler does not support CompileAheadOfTime.");
190   }
191 
192   // Returns a function that computes the size in bytes of the logical
193   // buffer that contains a shape.
ShapeSizeBytesFunction() const194   HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {
195     return [this](const xla::Shape& shape) {
196       XLA_Shape c_shape;
197       ApiConverter::ToC(shape, &c_shape);
198       int64_t bytes =
199           ExecutorApiFn()->TpuCompiler_ShapeSizeFn(compiler_, &c_shape);
200       ApiConverter::Destroy(&c_shape);
201       return bytes;
202     };
203   }
204 
DefaultDeviceShapeRepresentation(const Shape & shape) const205   Shape DefaultDeviceShapeRepresentation(const Shape& shape) const override {
206     XLA_Shape host_shape, device_shape;
207     ApiConverter::ToC(shape, &host_shape);
208     ExecutorApiFn()->TpuCompiler_DefaultDeviceShapeRepresentationFn(
209         compiler_, &host_shape, &device_shape);
210     return ApiConverter::FromC(&device_shape);
211   }
212 
213  private:
214   Tpu_Compiler* compiler_;
215 };
216 
InitModule()217 static bool InitModule() {
218   xla::Compiler::RegisterCompilerFactory(
219       tensorflow::tpu::GetTpuPlatformId(),
220       []() { return std::make_unique<TpuCompiler>(); });
221   return true;
222 }
223 
// Static-initialization hook: forces InitModule() to run at load time so
// the TPU compiler factory is registered before any compilation request.
static bool module_initialized = InitModule();
225 
226 }  // namespace
227 }  // namespace xla
228