/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
#if GOOGLE_CUDA
#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h"
#include "tensorflow/compiler/xla/service/gpu/nvptx_helper.h"
#endif
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"

const char* const kUsage = R"(
This tool reads in an HloModule from a file, compiles it using the NVPTX
compiler, and prints out the LLVM IR generated by the IR emitter. The LLVM IR
is not optimized by the LLVM pass pipeline, so this tool can be used to unit
test the XLA GPU IR emitters.

Note that the LLVM IR does not contain the *full* module, but only the parts
that will be code-generated into PTX. The NVPTX compiler also generates a
GpuExecutable on the side that is not printed.

When passed the flag `--ptx`, the LLVM IR is optimized, and PTX is emitted
and printed instead of the unoptimized LLVM IR.
By default, SM 70 is targeted, but this can be changed with `--sm=SM`.)";
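
// Example invocation and input. The binary name `hlo_to_llvm_ir` is an
// assumption (inferred from this file's purpose), not taken from the build:
//
//   hlo_to_llvm_ir --ptx --sm=80 /tmp/add.hlo
//
// where /tmp/add.hlo contains an HLO module in text form, e.g.:
//
//   HloModule add_module
//
//   ENTRY main {
//     a = f32[4] parameter(0)
//     b = f32[4] parameter(1)
//     ROOT sum = f32[4] add(a, b)
//   }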

namespace {
xla::Status CompileAndPrintLlvmIr(const std::string& hlo_text,
                                  bool generate_ptx, int sm) {
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<xla::HloModule> hlo_module,
      xla::LoadModuleFromData(/*data=*/hlo_text, /*format=*/"hlo"));
  llvm::LLVMContext llvm_context;
  // For now we pretend we're compiling for V100. This can be generalized
  // later.

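  // The values below describe a Tesla V100: 80 SMs, 48 KiB (49152 bytes) of
  // shared memory per block, up to 1024 threads per block, 2048 resident
  // threads per SM, and CUDA's standard grid dimension limits
  // (2^31-1, 65535, 65535).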
  xla::gpu::GpuDeviceInfo gpu_device_info{};
  gpu_device_info.threads_per_block_limit = 1024;
  gpu_device_info.threads_per_warp = 32;
  gpu_device_info.shared_memory_per_block = 49152;
  gpu_device_info.core_count = 80;
  gpu_device_info.threads_per_core_limit = 2048;
  gpu_device_info.block_dim_limit_x = 2147483647;
  gpu_device_info.block_dim_limit_y = 65535;
  gpu_device_info.block_dim_limit_z = 65535;

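  // The two-digit --sm value encodes a CUDA compute capability: e.g. --sm=70
  // becomes 7.0 (Volta) and --sm=86 becomes 8.6 (Ampere).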
  tensorflow::se::CudaComputeCapability cuda_compute_capability;
  cuda_compute_capability.major = sm / 10;
  cuda_compute_capability.minor = sm % 10;
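  // CompileModuleToLlvmIr takes capabilities for both vendors; the ROCm one
  // (gfx908, i.e. MI100) is presumably unused when targeting NVPTX.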
  tensorflow::se::RocmComputeCapability rocm_compute_capability("gfx908");
  std::string platform_name = "CUDA";
  stream_executor::Platform::Id platform_id =
      stream_executor::cuda::kCudaPlatformId;
  TF_ASSIGN_OR_RETURN(std::unique_ptr<llvm::Module> llvm_module,
                      xla::gpu::CompileModuleToLlvmIr(
                          hlo_module.get(), &llvm_context,
                          /*target_triple=*/xla::gpu::nvptx::TargetTriple(),
                          /*data_layout=*/xla::gpu::nvptx::DataLayout(),
                          /*platform_name=*/platform_name,
                          /*platform_id=*/platform_id, gpu_device_info,
                          cuda_compute_capability, rocm_compute_capability,
                          /*pointer_size=*/8));

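  // Without --ptx, print the unoptimized LLVM IR produced by the XLA
  // emitters. With --ptx, run LLVM's NVPTX backend (which needs libdevice)
  // and print the resulting PTX instead; that path requires a CUDA build.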
  if (!generate_ptx) {
    llvm_module->print(llvm::outs(), nullptr);
  } else {
#if GOOGLE_CUDA
    std::string libdevice_dir =
        xla::gpu::GetLibdeviceDir(hlo_module->config());
    TF_ASSIGN_OR_RETURN(std::string ptx,
                        xla::gpu::nvptx::CompileToPtx(
                            llvm_module.get(), cuda_compute_capability,
                            hlo_module->config(), libdevice_dir));
    std::cout << ptx << std::endl;
#else
    return {tensorflow::error::UNIMPLEMENTED,
            "PTX emission requires a CUDA-enabled build"};
#endif
  }
  return xla::OkStatus();
}

xla::Status CompileAndPrintLlvmIrFromFile(const std::string& file_name,
                                          bool ptx, int sm) {
  std::string full_text;
  TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(),
                                                  file_name, &full_text));

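  // One input file may hold several HLO modules separated by a "// -----"
  // line (the same separator MLIR's -split-input-file uses); compile and
  // print each one in turn.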
  std::vector<std::string> hlo_module_texts =
      absl::StrSplit(full_text, "// -----");
  for (const std::string& hlo_module_text : hlo_module_texts) {
    TF_RETURN_IF_ERROR(CompileAndPrintLlvmIr(hlo_module_text, ptx, sm));
  }

  return xla::OkStatus();
}
}  // namespace

int main(int argc, char** argv) {
  bool ptx = false;
  int sm = 70;
  std::vector<tensorflow::Flag> flag_list;
  xla::AppendDebugOptionsFlags(&flag_list);
  flag_list.emplace_back("ptx", &ptx,
                         "Print PTX instead of unoptimized LLVM IR.");
  flag_list.emplace_back("sm", &sm,
                         "Specify the SM to target (useful only with --ptx).");
  // The usage string includes the message at the top of the file, the
  // DebugOptions flags, and the flags defined above.
  const std::string kUsageString = absl::StrCat(
      kUsage, "\n\n", tensorflow::Flags::Usage(argv[0], flag_list));
  bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
  tensorflow::port::InitMain(kUsageString.c_str(), &argc, &argv);
  if (!parse_ok) {
    LOG(QFATAL) << kUsageString;
  }

  QCHECK(argc == 2) << "Must specify a single input file";
  TF_CHECK_OK(CompileAndPrintLlvmIrFromFile(argv[1], ptx, sm));

  return 0;
}