• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/llvm_compiler.h"
17 
18 #include "absl/memory/memory.h"
19 #include "tensorflow/compiler/xla/literal_util.h"
20 #include "tensorflow/compiler/xla/service/backend.h"
21 #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
22 #include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/platform_util.h"
25 #include "tensorflow/compiler/xla/test_helpers.h"
26 #include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
27 #include "tensorflow/core/platform/test.h"
28 #include "tensorflow/stream_executor/stream_executor.h"
29 
30 namespace xla {
31 namespace gpu {
32 
33 // Creating dummy data structure needed to initialize a GpuDummyCompiler
34 PLATFORM_DEFINE_ID(kDummyTestId);
35 constexpr char kDummyTriple[] = "dummy-triple";
36 constexpr char kDummyLayout[] = "e";
37 
38 // This class is a dummy implementation of GpuCompiler and is targeted for unit
39 // test only
40 class GpuDummyCompiler : public GpuCompiler {
41  public:
GpuDummyCompiler()42   GpuDummyCompiler() : GpuCompiler(kDummyTestId, kDummyTriple, kDummyLayout) {}
43 
OptimizeHloConvolutionCanonicalization(HloModule * hlo_module,se::StreamExecutor * stream_exec,se::DeviceMemoryAllocator * device_allocator)44   Status OptimizeHloConvolutionCanonicalization(
45       HloModule* hlo_module, se::StreamExecutor* stream_exec,
46       se::DeviceMemoryAllocator* device_allocator) {
47     return Status::OK();
48   }
49 
OptimizeHloPostLayoutAssignment(HloModule * hlo_module,se::StreamExecutor * stream_exec,se::DeviceMemoryAllocator * device_allocator)50   Status OptimizeHloPostLayoutAssignment(
51       HloModule* hlo_module, se::StreamExecutor* stream_exec,
52       se::DeviceMemoryAllocator* device_allocator) {
53     return Status::OK();
54   }
55 
GetGpuVersion(se::StreamExecutor *)56   GpuVersion GetGpuVersion(se::StreamExecutor*) override {
57     return se::CudaComputeCapability{0, 0};
58   }
59 
CompileTargetBinary(const HloModuleConfig & module_config,llvm::Module * llvm_module,GpuVersion gpu_version,se::StreamExecutor * stream_exec,bool relocatable,const HloModule * debug_module)60   StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
61       const HloModuleConfig& module_config, llvm::Module* llvm_module,
62       GpuVersion gpu_version, se::StreamExecutor* stream_exec, bool relocatable,
63       const HloModule* debug_module) {
64     std::vector<uint8> compiled_results;
65     return std::pair<std::string, std::vector<uint8>>(
66         "", std::move(compiled_results));
67   }
68 };
69 }  // namespace gpu
70 
71 namespace {
72 
73 class LLVMCompilerTest : public ::testing::Test {
74  public:
SetUp()75   void SetUp() override {
76     Platform *platform = FindPlatform();
77     ASSERT_NE(platform, nullptr);
78 
79     BackendOptions backend_options;
80     backend_options.set_platform(platform);
81     StatusOr<std::unique_ptr<Backend>> backend_or_status =
82         Backend::CreateBackend(backend_options);
83     ASSERT_IS_OK(backend_or_status.status());
84     backend_ = backend_or_status.ConsumeValueOrDie();
85   }
86 
~LLVMCompilerTest()87   ~LLVMCompilerTest() override {}
88 
89  protected:
90   using Platform = se::Platform;
91 
LLVMCompilerTest(string platform_name)92   explicit LLVMCompilerTest(string platform_name)
93       : platform_name_(std::move(platform_name)) {}
94 
TestCompilerHooks(LLVMCompiler * compiler)95   void TestCompilerHooks(LLVMCompiler *compiler) {
96     int pre_opt_hook_call_count = 0;
97     int post_opt_hook_call_count = 0;
98 
99     auto pre_opt_hook = [&pre_opt_hook_call_count](const llvm::Module &) {
100       ++pre_opt_hook_call_count;
101       return Status::OK();
102     };
103     auto post_opt_hook = [&post_opt_hook_call_count](const llvm::Module &) {
104       ++post_opt_hook_call_count;
105       return Status::OK();
106     };
107 
108     // Create HLO module, and run the compiler.
109     auto builder = HloComputation::Builder(TestName());
110     builder.AddInstruction(
111         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
112 
113     auto hlo_module = CreateNewVerifiedModule();
114     hlo_module->AddEntryComputation(builder.Build());
115 
116     compiler->SetPreOptimizationHook(pre_opt_hook);
117     compiler->SetPostOptimizationHook(post_opt_hook);
118 
119     ASSERT_TRUE(compiler
120                     ->RunBackend(std::move(hlo_module),
121                                  backend_->default_stream_executor(),
122                                  /*device_allocator=*/nullptr)
123                     .ok());
124 
125     // Test that hooks were called.
126     EXPECT_EQ(1, pre_opt_hook_call_count);
127     EXPECT_EQ(1, post_opt_hook_call_count);
128   }
129 
TestMultiModuleCompilation(LLVMCompiler * compiler)130   void TestMultiModuleCompilation(LLVMCompiler *compiler) {
131     HloComputation::Builder builder(TestName());
132     builder.AddInstruction(
133         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
134 
135     std::unique_ptr<HloModule> hlo_module = CreateNewVerifiedModule();
136     hlo_module->AddEntryComputation(builder.Build());
137 
138     auto module_group = absl::make_unique<HloModuleGroup>("test_module_group");
139     module_group->push_back(hlo_module->Clone());
140     module_group->push_back(std::move(hlo_module));
141 
142     std::vector<std::vector<se::StreamExecutor *>> executors;
143     executors.push_back({backend_->default_stream_executor()});
144     executors.push_back({backend_->default_stream_executor()});
145 
146     EXPECT_IS_OK(compiler->Compile(std::move(module_group),
147                                    std::move(executors),
148                                    /*device_allocator=*/nullptr));
149   }
150 
151  private:
FindPlatform()152   Platform *FindPlatform() {
153     auto status_or_platform = PlatformUtil::GetPlatform(platform_name_);
154     return status_or_platform.ok() ? status_or_platform.ValueOrDie() : nullptr;
155   }
156 
157   string platform_name_;
158   std::unique_ptr<Backend> backend_;
159 
TestName()160   static string TestName() {
161     return ::testing::UnitTest::GetInstance()->current_test_info()->name();
162   }
163 
CreateNewVerifiedModule()164   std::unique_ptr<HloModule> CreateNewVerifiedModule() {
165     HloModuleConfig config;
166     config.set_debug_options(GetDebugOptionsFromFlags());
167     return absl::make_unique<VerifiedHloModule>(
168         TestName(), config, /*verifier_layout_sensitive=*/false,
169         /*allow_mixed_precision_in_hlo_verifier=*/true,
170         backend_->compiler()->ShapeSizeBytesFunction());
171   }
172 };
173 
174 class CpuCompilerTest : public LLVMCompilerTest {
175  public:
CpuCompilerTest()176   CpuCompilerTest() : LLVMCompilerTest("Host") {}
177 };
178 
179 class GpuCompilerTest : public LLVMCompilerTest {
180  public:
GpuCompilerTest()181   GpuCompilerTest() : LLVMCompilerTest("GPU") {}
182 };
183 
// Runs the optimization-hook check against the CPU compiler.
TEST_F(CpuCompilerTest, HooksTest) {
  cpu::CpuCompiler cpu_compiler;
  TestCompilerHooks(&cpu_compiler);
}
188 
// Runs the optimization-hook check against the dummy GPU compiler.
TEST_F(GpuCompilerTest, HooksTest) {
  gpu::GpuDummyCompiler dummy_gpu_compiler;
  TestCompilerHooks(&dummy_gpu_compiler);
}
193 
// Runs the multi-module compilation check against the CPU compiler.
TEST_F(CpuCompilerTest, CpuMultiModuleCompilation) {
  cpu::CpuCompiler cpu_compiler;
  TestMultiModuleCompilation(&cpu_compiler);
}
198 
// Runs the multi-module compilation check against the dummy GPU compiler.
TEST_F(GpuCompilerTest, GpuMultModuleCompilation) {
  gpu::GpuDummyCompiler dummy_gpu_compiler;
  TestMultiModuleCompilation(&dummy_gpu_compiler);
}
203 }  // namespace
204 }  // namespace xla
205