/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/tests/mlir_gpu_test_base.h"

#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/IR/DialectRegistry.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"

namespace xla {
namespace gpu {

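// Creates a backend on the GPU platform registered under
// tensorflow::GpuPlatformName(); value() check-fails if no GPU is available.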
MlirGpuTestBase::MlirGpuTestBase() {
  se::Platform* platform =
      se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName())
          .value();
  BackendOptions options;
  options.set_platform(platform);
  backend_ = xla::Backend::CreateBackend(options).value();
}

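// Borrows a stream from the pool of the default device.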
StreamPool::Ptr MlirGpuTestBase::BorrowStream() {
  return *backend_->BorrowStream(backend_->default_device_ordinal());
}

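// Compiles an LMHLO module into a GPU executable, targeting the device that
// `stream` runs on.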
StatusOr<std::unique_ptr<Executable>> MlirGpuTestBase::CompileMlirModule(
    mlir::ModuleOp module, se::Stream* stream) {
  llvm::LLVMContext llvm_context;
  auto llvm_module = std::make_unique<llvm::Module>("", llvm_context);
#if TENSORFLOW_USE_ROCM
  llvm_module->setTargetTriple(amdgpu::TargetTriple());
  llvm_module->setDataLayout(amdgpu::DataLayout());
#else
  llvm_module->setTargetTriple(nvptx::TargetTriple());
  llvm_module->setDataLayout(nvptx::DataLayout());
#endif

  se::StreamExecutor* stream_exec = stream->parent();
  GpuDeviceInfo gpu_device_info = GetGpuDeviceInfo(stream_exec);
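  // No HloModule or buffer assignment is passed: compilation starts from the
  // LMHLO module, which already carries buffer information.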
  IrEmitterContext ir_emitter_context(
      /*hlo_module=*/nullptr, /*buffer_assignment=*/nullptr,
      backend_->platform()->Name(), gpu_device_info,
      stream_exec->GetDeviceDescription().cuda_compute_capability(),
      stream_exec->GetDeviceDescription().rocm_compute_capability(),
      /*mlir_context=*/nullptr, llvm_module.get());

  HloModuleConfig module_config;
  module_config.set_debug_options(GetDebugOptionsFromFlags());
  return CompileLmhloToExecutable(
      static_cast<GpuCompiler*>(backend_->compiler()), module, "TestModule",
      module_config, Compiler::CompileOptions(), "main", stream_exec,
      std::move(llvm_module), &ir_emitter_context);
}

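// Compiles `module` and runs it on `stream` with the given device buffers as
// arguments, blocking until execution finishes.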
StatusOr<ExecutionOutput> MlirGpuTestBase::RunMlirModule(
    mlir::ModuleOp module, se::Stream* stream,
    absl::Span<const se::DeviceMemoryBase> arguments) {
  TF_ASSIGN_OR_RETURN(auto executable, CompileMlirModule(module, stream));

  ExecutableRunOptions executable_run_options;
  executable_run_options.set_stream(stream);
  executable_run_options.set_allocator(backend_->memory_allocator());
  ServiceExecutableRunOptions run_options(executable_run_options,
                                          backend_->StreamBorrower());
  std::vector<ExecutionInput> execution_inputs;
  execution_inputs.reserve(arguments.size());

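  // Wrap each device buffer as a u8[size] input; an unowned
  // MaybeOwningDeviceMemory keeps ownership with the caller.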
  for (auto arg : arguments) {
    Shape shape =
        ShapeUtil::MakeShape(xla::U8, {static_cast<int64_t>(arg.size())});
    execution_inputs.emplace_back(shape);
    execution_inputs.back().SetBuffer({}, MaybeOwningDeviceMemory(arg));
  }

  TF_ASSIGN_OR_RETURN(auto output,
                      executable->ExecuteAsyncOnStream(
                          &run_options, std::move(execution_inputs),
                          /*hlo_execution_profile=*/nullptr));

  TF_CHECK_OK(stream->BlockHostUntilDone());

  return std::move(output);
}

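// Convenience wrapper around RunMlirModule: allocates device memory for each
// host buffer, runs the module, and copies the results back to the host.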
StatusOr<std::vector<std::vector<uint8_t>>>
MlirGpuTestBase::RunMlirModuleWithHostBuffers(
    mlir::ModuleOp module, std::vector<absl::Span<uint8_t>> arguments) {
  auto* allocator = backend_->memory_allocator();
  std::vector<se::OwningDeviceMemory> owning_memory;
  owning_memory.reserve(arguments.size());
  for (auto host_buffer : arguments) {
    owning_memory.push_back(
        allocator
            ->Allocate(backend_->default_device_ordinal(), host_buffer.size())
            .value());
  }
  auto stream =
      backend_->BorrowStream(backend_->default_device_ordinal()).value();
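  // Copy each host argument into its device allocation.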
  std::vector<se::DeviceMemoryBase> args;
  for (int i = 0; i < owning_memory.size(); i++) {
    se::DeviceMemoryBase memory(*owning_memory[i]);
    stream->ThenMemcpy(&memory, static_cast<void*>(arguments[i].data()),
                       memory.size());
    args.push_back(memory);
  }
  TF_ASSIGN_OR_RETURN(ExecutionOutput output,
                      RunMlirModule(module, stream.get(), args));

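  // Copy every leaf buffer of the result back into a host vector.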
  std::vector<std::vector<uint8_t>> host_outputs;
  for (const auto& result : output.Result().buffers().leaves()) {
    host_outputs.emplace_back();
    host_outputs.back().resize(result.second.size());
    stream->ThenMemcpy(static_cast<void*>(host_outputs.back().data()),
                       result.second, result.second.size());
  }
  TF_CHECK_OK(stream->BlockHostUntilDone());
  return host_outputs;
}

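// Parses `module_text` into an MLIR module, registering the dialects the GPU
// IR emitter depends on.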
StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> MlirGpuTestBase::ParseMlirModule(
    absl::string_view module_text, mlir::MLIRContext& context) {
  mlir::DialectRegistry registry;
  xla::gpu::IrEmitterUnnested::GetDependentDialects(registry);
  context.appendDialectRegistry(registry);
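  // Route parser diagnostics into a string so a parse failure can be reported
  // in the returned status.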
  llvm::SourceMgr source_mgr;
  std::string diagnostic_str;
  llvm::raw_string_ostream os(diagnostic_str);
  mlir::SourceMgrDiagnosticHandler handler(source_mgr, &context, os);

  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::parseSourceString<mlir::ModuleOp>(
          llvm::StringRef(module_text.data(), module_text.size()), &context);
  if (!module) {
    return InvalidArgument("Failed to parse MLIR module: %s", diagnostic_str);
  }
  return module;
}

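// Parses `module_text` and runs it with the given host buffers as arguments.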
StatusOr<std::vector<std::vector<uint8_t>>>
MlirGpuTestBase::RunMlirTextWithHostBuffers(
    absl::string_view module_text, std::vector<absl::Span<uint8_t>> arguments) {
  mlir::MLIRContext context;
  TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> module,
                      ParseMlirModule(module_text, context));
  return RunMlirModuleWithHostBuffers(*module, arguments);
}

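// Parses `module_text` and compiles it to an executable without running it.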
StatusOr<std::unique_ptr<Executable>> MlirGpuTestBase::CompileMlirText(
    absl::string_view module_text) {
  mlir::MLIRContext context;
  TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> module,
                      ParseMlirModule(module_text, context));
  auto stream =
      backend_->BorrowStream(backend_->default_device_ordinal()).value();
  return CompileMlirModule(*module, stream.get());
}

}  // namespace gpu
}  // namespace xla