• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===--- ROCm.h - ROCm installation detector --------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_CLANG_LIB_DRIVER_TOOLCHAINS_ROCM_H
10 #define LLVM_CLANG_LIB_DRIVER_TOOLCHAINS_ROCM_H
11 
12 #include "clang/Basic/Cuda.h"
13 #include "clang/Basic/LLVM.h"
14 #include "clang/Driver/Driver.h"
15 #include "clang/Driver/Options.h"
16 #include "llvm/ADT/SmallString.h"
17 #include "llvm/ADT/StringMap.h"
18 #include "llvm/ADT/Triple.h"
19 #include "llvm/Option/ArgList.h"
20 #include "llvm/Support/VersionTuple.h"
21 
22 namespace clang {
23 namespace driver {
24 
25 /// A class to find a viable ROCM installation
26 /// TODO: Generalize to handle libclc.
27 class RocmInstallationDetector {
28 private:
29   struct ConditionalLibrary {
30     SmallString<0> On;
31     SmallString<0> Off;
32 
isValidConditionalLibrary33     bool isValid() const { return !On.empty() && !Off.empty(); }
34 
getConditionalLibrary35     StringRef get(bool Enabled) const {
36       assert(isValid());
37       return Enabled ? On : Off;
38     }
39   };
40 
41   // Installation path candidate.
42   struct Candidate {
43     llvm::SmallString<0> Path;
44     bool StrictChecking;
45 
46     Candidate(std::string Path, bool StrictChecking = false)
PathCandidate47         : Path(Path), StrictChecking(StrictChecking) {}
48   };
49 
50   const Driver &D;
51   bool HasHIPRuntime = false;
52   bool HasDeviceLibrary = false;
53 
54   // Default version if not detected or specified.
55   const unsigned DefaultVersionMajor = 3;
56   const unsigned DefaultVersionMinor = 5;
57   const char *DefaultVersionPatch = "0";
58 
59   // The version string in Major.Minor.Patch format.
60   std::string DetectedVersion;
61   // Version containing major and minor.
62   llvm::VersionTuple VersionMajorMinor;
63   // Version containing patch.
64   std::string VersionPatch;
65 
66   // ROCm path specified by --rocm-path.
67   StringRef RocmPathArg;
68   // ROCm device library paths specified by --rocm-device-lib-path.
69   std::vector<std::string> RocmDeviceLibPathArg;
70   // HIP version specified by --hip-version.
71   StringRef HIPVersionArg;
72   // Wheter -nogpulib is specified.
73   bool NoBuiltinLibs = false;
74 
75   // Paths
76   SmallString<0> InstallPath;
77   SmallString<0> BinPath;
78   SmallString<0> LibPath;
79   SmallString<0> LibDevicePath;
80   SmallString<0> IncludePath;
81   llvm::StringMap<std::string> LibDeviceMap;
82 
83   // Libraries that are always linked.
84   SmallString<0> OCML;
85   SmallString<0> OCKL;
86 
87   // Libraries that are always linked depending on the language
88   SmallString<0> OpenCL;
89   SmallString<0> HIP;
90 
91   // Libraries swapped based on compile flags.
92   ConditionalLibrary WavefrontSize64;
93   ConditionalLibrary FiniteOnly;
94   ConditionalLibrary UnsafeMath;
95   ConditionalLibrary DenormalsAreZero;
96   ConditionalLibrary CorrectlyRoundedSqrt;
97 
allGenericLibsValid()98   bool allGenericLibsValid() const {
99     return !OCML.empty() && !OCKL.empty() && !OpenCL.empty() && !HIP.empty() &&
100            WavefrontSize64.isValid() && FiniteOnly.isValid() &&
101            UnsafeMath.isValid() && DenormalsAreZero.isValid() &&
102            CorrectlyRoundedSqrt.isValid();
103   }
104 
105   void scanLibDevicePath(llvm::StringRef Path);
106   void ParseHIPVersionFile(llvm::StringRef V);
107   SmallVector<Candidate, 4> getInstallationPathCandidates();
108 
109 public:
110   RocmInstallationDetector(const Driver &D, const llvm::Triple &HostTriple,
111                            const llvm::opt::ArgList &Args,
112                            bool DetectHIPRuntime = true,
113                            bool DetectDeviceLib = false);
114 
115   /// Add arguments needed to link default bitcode libraries.
116   void addCommonBitcodeLibCC1Args(const llvm::opt::ArgList &DriverArgs,
117                                   llvm::opt::ArgStringList &CC1Args,
118                                   StringRef LibDeviceFile, bool Wave64,
119                                   bool DAZ, bool FiniteOnly, bool UnsafeMathOpt,
120                                   bool FastRelaxedMath, bool CorrectSqrt) const;
121 
122   /// Check whether we detected a valid HIP runtime.
hasHIPRuntime()123   bool hasHIPRuntime() const { return HasHIPRuntime; }
124 
125   /// Check whether we detected a valid ROCm device library.
hasDeviceLibrary()126   bool hasDeviceLibrary() const { return HasDeviceLibrary; }
127 
128   /// Print information about the detected ROCm installation.
129   void print(raw_ostream &OS) const;
130 
131   /// Get the detected Rocm install's version.
132   // RocmVersion version() const { return Version; }
133 
134   /// Get the detected Rocm installation path.
getInstallPath()135   StringRef getInstallPath() const { return InstallPath; }
136 
137   /// Get the detected path to Rocm's bin directory.
138   // StringRef getBinPath() const { return BinPath; }
139 
140   /// Get the detected Rocm Include path.
getIncludePath()141   StringRef getIncludePath() const { return IncludePath; }
142 
143   /// Get the detected Rocm library path.
getLibPath()144   StringRef getLibPath() const { return LibPath; }
145 
146   /// Get the detected Rocm device library path.
getLibDevicePath()147   StringRef getLibDevicePath() const { return LibDevicePath; }
148 
getOCMLPath()149   StringRef getOCMLPath() const {
150     assert(!OCML.empty());
151     return OCML;
152   }
153 
getOCKLPath()154   StringRef getOCKLPath() const {
155     assert(!OCKL.empty());
156     return OCKL;
157   }
158 
getOpenCLPath()159   StringRef getOpenCLPath() const {
160     assert(!OpenCL.empty());
161     return OpenCL;
162   }
163 
getHIPPath()164   StringRef getHIPPath() const {
165     assert(!HIP.empty());
166     return HIP;
167   }
168 
getWavefrontSize64Path(bool Enabled)169   StringRef getWavefrontSize64Path(bool Enabled) const {
170     return WavefrontSize64.get(Enabled);
171   }
172 
getFiniteOnlyPath(bool Enabled)173   StringRef getFiniteOnlyPath(bool Enabled) const {
174     return FiniteOnly.get(Enabled);
175   }
176 
getUnsafeMathPath(bool Enabled)177   StringRef getUnsafeMathPath(bool Enabled) const {
178     return UnsafeMath.get(Enabled);
179   }
180 
getDenormalsAreZeroPath(bool Enabled)181   StringRef getDenormalsAreZeroPath(bool Enabled) const {
182     return DenormalsAreZero.get(Enabled);
183   }
184 
getCorrectlyRoundedSqrtPath(bool Enabled)185   StringRef getCorrectlyRoundedSqrtPath(bool Enabled) const {
186     return CorrectlyRoundedSqrt.get(Enabled);
187   }
188 
189   /// Get libdevice file for given architecture
getLibDeviceFile(StringRef Gpu)190   std::string getLibDeviceFile(StringRef Gpu) const {
191     return LibDeviceMap.lookup(Gpu);
192   }
193 
194   void AddHIPIncludeArgs(const llvm::opt::ArgList &DriverArgs,
195                          llvm::opt::ArgStringList &CC1Args) const;
196 
197   void detectDeviceLibrary();
198   void detectHIPRuntime();
199 
200   /// Get the values for --rocm-device-lib-path arguments
getRocmDeviceLibPathArg()201   std::vector<std::string> getRocmDeviceLibPathArg() const {
202     return RocmDeviceLibPathArg;
203   }
204 
205   /// Get the value for --rocm-path argument
getRocmPathArg()206   StringRef getRocmPathArg() const { return RocmPathArg; }
207 
208   /// Get the value for --hip-version argument
getHIPVersionArg()209   StringRef getHIPVersionArg() const { return HIPVersionArg; }
210 
getHIPVersion()211   std::string getHIPVersion() const { return DetectedVersion; }
212 };
213 
214 } // end namespace driver
215 } // end namespace clang
216 
217 #endif // LLVM_CLANG_LIB_DRIVER_TOOLCHAINS_ROCM_H
218