• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Prints the ROCm install path and component versions found on the system.

The script locates version files and headers of the ROCm toolkit and its
component libraries, parses them to determine their versions, and prints the
configuration to stdout. The installation to inspect is specified through the
ROCM_PATH environment variable. If no valid configuration is found, the script
prints an error to stderr and exits with status 1.
22
The script takes the directory specified by the ROCM_PATH environment variable
and looks for version headers in a hard-coded set of subdirectories relative
to that directory. If ROCM_PATH is not set, "/opt/rocm" is used as its
default value.
27
28"""
29
30import io
31import os
32import re
33import sys
34
35
class ConfigError(Exception):
  """Raised when a usable ROCm configuration cannot be determined."""
38
39
40def _get_default_rocm_path():
41  return "/opt/rocm"
42
43
def _get_rocm_install_path():
  """Returns the ROCm root: $ROCM_PATH if set, else the default location.

  The value is returned as-is; it is not resolved to a real path.
  """
  return os.environ.get("ROCM_PATH", _get_default_rocm_path())
51
52
53def _get_composite_version_number(major, minor, patch):
54  return 10000 * major + 100 * minor + patch
55
56
57def _get_header_version(path, name):
58  """Returns preprocessor defines in C header file."""
59  for line in io.open(path, "r", encoding="utf-8"):
60    match = re.match(r"#define %s +(\d+)" % name, line)
61    if match:
62      value = match.group(1)
63      return int(value)
64
65  raise ConfigError('#define "{}" is either\n'.format(name) +
66                    "  not present in file {} OR\n".format(path) +
67                    "  its value is not an integer literal")
68
69
def _find_rocm_config(rocm_install_path):
  """Returns the overall ROCm version read from the install tree.

  Args:
    rocm_install_path: Root directory of the ROCm installation.

  Returns:
    A dict with key "rocm_version_number" holding the composite version
    produced by _get_composite_version_number(major, minor, patch).

  Raises:
    ConfigError: If the version file is missing or cannot be parsed.
  """

  def rocm_version_numbers(path):
    """Parses (major, minor, patch) from <path>/.info/version-dev."""
    version_file = os.path.join(path, ".info/version-dev")
    if not os.path.exists(version_file):
      raise ConfigError('ROCm version file "{}" not found'.format(version_file))
    with open(version_file) as f:
      version_string = f.read().strip()
    # Expects "<major>.<minor>.<patch>", where the patch component may carry
    # a "-<suffix>" that is discarded.
    version_numbers = version_string.split(".")
    try:
      major = int(version_numbers[0])
      minor = int(version_numbers[1])
      patch = int(version_numbers[2].split("-")[0])
    except (IndexError, ValueError):
      # Surface malformed content as ConfigError so main() reports it
      # cleanly instead of crashing with a traceback.
      raise ConfigError(
          'Unable to parse ROCm version "{}" from file "{}"'.format(
              version_string, version_file))
    return major, minor, patch

  major, minor, patch = rocm_version_numbers(rocm_install_path)

  rocm_config = {
      "rocm_version_number": _get_composite_version_number(major, minor, patch)
  }

  return rocm_config
92
93
def _find_hipruntime_config(rocm_install_path):
  """Reads the HIP runtime version from hip/include/hip/hip_version.h.

  Returns:
    {"hipruntime_version_number": 100 * HIP_VERSION_MAJOR + HIP_VERSION_MINOR}

  Raises:
    ConfigError: If the header is missing or lacks the version defines.
  """
  header = os.path.join(rocm_install_path, "hip/include/hip/hip_version.h")
  if not os.path.exists(header):
    raise ConfigError(
        'HIP Runtime version file "{}" not found'.format(header))
  # HIP_VERSION in this header is defined as an arithmetic expression
  # (HIP_VERSION_MAJOR * 100 + HIP_VERSION_MINOR). Rather than evaluating
  # that expression, retrieve the two components and recombine them here.
  hip_major = _get_header_version(header, "HIP_VERSION_MAJOR")
  hip_minor = _get_header_version(header, "HIP_VERSION_MINOR")
  return {"hipruntime_version_number": 100 * hip_major + hip_minor}
114
115
def _find_miopen_config(rocm_install_path):
  """Reads the MIOpen version from miopen/include/miopen/version.h.

  Returns:
    {"miopen_version_number": composite of MAJOR/MINOR/PATCH}

  Raises:
    ConfigError: If the header is missing or lacks a version define.
  """
  header = os.path.join(rocm_install_path, "miopen/include/miopen/version.h")
  if not os.path.exists(header):
    raise ConfigError(
        'MIOpen version file "{}" not found'.format(header))
  components = [
      _get_header_version(header, macro)
      for macro in ("MIOPEN_VERSION_MAJOR",
                    "MIOPEN_VERSION_MINOR",
                    "MIOPEN_VERSION_PATCH")
  ]
  return {"miopen_version_number": _get_composite_version_number(*components)}
136
137
def _find_rocblas_config(rocm_install_path):
  """Reads the rocBLAS version from the first rocblas-version.h that exists.

  Returns:
    {"rocblas_version_number": composite of MAJOR/MINOR/PATCH}

  Raises:
    ConfigError: If no candidate header exists or a version define is missing.
  """
  candidates = [
      "rocblas/include/rocblas-version.h",  # ROCm 3.7 and prior
      "rocblas/include/internal/rocblas-version.h",  # ROCm 3.8
  ]
  # Lazily probe the candidates in order and stop at the first one present.
  existing = (
      full_path
      for full_path in (os.path.join(rocm_install_path, rel)
                        for rel in candidates)
      if os.path.exists(full_path))
  header = next(existing, None)
  if header is None:
    raise ConfigError(
        "rocblas version file not found in {}".format(candidates))

  blas_major = _get_header_version(header, "ROCBLAS_VERSION_MAJOR")
  blas_minor = _get_header_version(header, "ROCBLAS_VERSION_MINOR")
  blas_patch = _get_header_version(header, "ROCBLAS_VERSION_PATCH")
  return {
      "rocblas_version_number":
          _get_composite_version_number(blas_major, blas_minor, blas_patch)
  }
168
169
def _find_rocrand_config(rocm_install_path):
  """Returns the rocRAND version read from rocrand_version.h.

  Args:
    rocm_install_path: Root directory of the ROCm installation.

  Returns:
    A dict with key "rocrand_version_number" holding the integer value of the
    ROCRAND_VERSION define, as-is (not a composite version).

  Raises:
    ConfigError: If the header is missing or does not define ROCRAND_VERSION
      as an integer literal.
  """

  def rocrand_version_number(path):
    version_file = os.path.join(path, "rocrand/include/rocrand_version.h")
    if not os.path.exists(version_file):
      raise ConfigError(
          'rocrand version file "{}" not found'.format(version_file))
    return _get_header_version(version_file, "ROCRAND_VERSION")

  rocrand_config = {
      "rocrand_version_number": rocrand_version_number(rocm_install_path)
  }

  return rocrand_config
185
186
def _find_rocfft_config(rocm_install_path):
  """Reads the rocFFT version from rocfft/include/rocfft-version.h.

  Returns:
    {"rocfft_version_number": composite of major/minor/patch}

  Raises:
    ConfigError: If the header is missing or lacks a version define.
  """
  header = os.path.join(rocm_install_path, "rocfft/include/rocfft-version.h")
  if not os.path.exists(header):
    raise ConfigError(
        'rocfft version file "{}" not found'.format(header))
  # rocFFT spells its version macros in lower case.
  fft_major, fft_minor, fft_patch = (
      _get_header_version(header, macro)
      for macro in ("rocfft_version_major",
                    "rocfft_version_minor",
                    "rocfft_version_patch"))
  version = _get_composite_version_number(fft_major, fft_minor, fft_patch)
  return {"rocfft_version_number": version}
207
208
def _find_hipfft_config(rocm_install_path):
  """Reads the hipFFT version from hipfft/include/hipfft-version.h.

  Returns:
    {"hipfft_version_number": composite of Major/Minor/Patch}

  Raises:
    ConfigError: If the header is missing or lacks a version define.
  """
  header = os.path.join(rocm_install_path, "hipfft/include/hipfft-version.h")
  if not os.path.exists(header):
    raise ConfigError(
        'hipfft version file "{}" not found'.format(header))
  version_parts = []
  for macro in ("hipfftVersionMajor", "hipfftVersionMinor",
                "hipfftVersionPatch"):
    version_parts.append(_get_header_version(header, macro))
  return {
      "hipfft_version_number": _get_composite_version_number(*version_parts)
  }
229
230
def _find_roctracer_config(rocm_install_path):
  """Reads the roctracer version from roctracer/include/roctracer.h.

  Returns:
    {"roctracer_version_number": composite of MAJOR/MINOR with patch 0}

  Raises:
    ConfigError: If the header is missing or lacks a version define.
  """
  header = os.path.join(rocm_install_path, "roctracer/include/roctracer.h")
  if not os.path.exists(header):
    raise ConfigError(
        'roctracer version file "{}" not found'.format(header))
  tracer_major = _get_header_version(header, "ROCTRACER_VERSION_MAJOR")
  tracer_minor = _get_header_version(header, "ROCTRACER_VERSION_MINOR")
  # The roctracer header provides no patch component, so it is fixed at 0.
  tracer_patch = 0
  return {
      "roctracer_version_number":
          _get_composite_version_number(tracer_major, tracer_minor,
                                        tracer_patch)
  }
252
253
def _find_hipsparse_config(rocm_install_path):
  """Reads the hipSPARSE version from hipsparse/include/hipsparse-version.h.

  Returns:
    {"hipsparse_version_number": composite of Major/Minor/Patch}

  Raises:
    ConfigError: If the header is missing or lacks a version define.
  """
  header = os.path.join(rocm_install_path,
                        "hipsparse/include/hipsparse-version.h")
  if not os.path.exists(header):
    raise ConfigError(
        'hipsparse version file "{}" not found'.format(header))
  sparse_major = _get_header_version(header, "hipsparseVersionMajor")
  sparse_minor = _get_header_version(header, "hipsparseVersionMinor")
  sparse_patch = _get_header_version(header, "hipsparseVersionPatch")
  composite = _get_composite_version_number(sparse_major, sparse_minor,
                                            sparse_patch)
  return {"hipsparse_version_number": composite}
274
275
def _find_rocsolver_config(rocm_install_path):
  """Reads the rocSOLVER version from rocsolver/include/rocsolver-version.h.

  Returns:
    {"rocsolver_version_number": composite of MAJOR/MINOR/PATCH}

  Raises:
    ConfigError: If the header is missing or lacks a version define.
  """
  header = os.path.join(rocm_install_path,
                        "rocsolver/include/rocsolver-version.h")
  if not os.path.exists(header):
    raise ConfigError(
        'rocsolver version file "{}" not found'.format(header))
  macros = ("ROCSOLVER_VERSION_MAJOR",
            "ROCSOLVER_VERSION_MINOR",
            "ROCSOLVER_VERSION_PATCH")
  numbers = [_get_header_version(header, macro) for macro in macros]
  return {
      "rocsolver_version_number": _get_composite_version_number(*numbers)
  }
296
297
def find_rocm_config():
  """Collects version info for the ROCm toolkit and its component libraries.

  Returns:
    A dict containing "rocm_toolkit_path" plus one "<component>_version_number"
    entry per probed component.

  Raises:
    ConfigError: If ROCM_PATH does not exist or any component lookup fails.
  """
  rocm_install_path = _get_rocm_install_path()
  if not os.path.exists(rocm_install_path):
    raise ConfigError(
        'Specified ROCM_PATH "{}" does not exist'.format(rocm_install_path))

  result = {"rocm_toolkit_path": rocm_install_path}
  result.update(_find_rocm_config(rocm_install_path))

  finders = [
      _find_hipruntime_config,
      _find_miopen_config,
      _find_rocblas_config,
      _find_rocrand_config,
      _find_rocfft_config,
  ]
  # hipFFT is only probed for ROCm 4.1.0 (composite 40100) and newer.
  if result["rocm_version_number"] >= 40100:
    finders.append(_find_hipfft_config)
  finders.extend([
      _find_roctracer_config,
      _find_hipsparse_config,
      _find_rocsolver_config,
  ])

  for finder in finders:
    result.update(finder(rocm_install_path))
  return result
321
322
def main():
  """Prints the discovered ROCm config as "key: value" lines sorted by key.

  On ConfigError, writes the message to stderr and exits with status 1.
  """
  try:
    config = find_rocm_config()
  except ConfigError as e:
    sys.stderr.write("\nERROR: {}\n\n".format(str(e)))
    sys.exit(1)
  for key in sorted(config):
    print("%s: %s" % (key, config[key]))
330
331
if __name__ == "__main__":
  # Run main() only when executed as a script, not when imported.
  main()
334