1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Library for getting system information during TensorFlow tests."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import re
23import shlex
24import subprocess
25import tempfile
26import time
27
28from tensorflow.core.util import test_log_pb2
29from tensorflow.python.platform import gfile
30from tensorflow.tools.test import gpu_info_lib
31from tensorflow.tools.test import system_info_lib
32
33
class MissingLogsError(Exception):
  """Raised when a benchmark run produced no log files to gather."""
36
37
def get_git_commit_sha():
  """Return the git commit SHA for this build, when available.

  Reads the GIT_COMMIT environment variable, which Jenkins build agents
  export for each build.

  Returns:
    The commit SHA string, or None if the variable is not set.
  """
  return os.environ.get("GIT_COMMIT")
49
50
def process_test_logs(name, test_name, test_args, benchmark_type,
                      start_time, run_time, log_files):
  """Gather test information and put it in a TestResults proto.

  Args:
    name: Benchmark target identifier.
    test_name: A unique bazel target, e.g. "//path/to:test"
    test_args: A list of strings, the individual arguments the target was
      run with (each becomes one entry of run_configuration.argument).
    benchmark_type: A string naming a TestResults.BenchmarkType enum value
      (case-insensitive; it is upper-cased before lookup).
    start_time: Test starting time (epoch)
    run_time:   Wall time that the test ran for
    log_files:  Paths to the log files

  Returns:
    A TestResults proto

  Raises:
    ValueError: If benchmark_type does not name a valid BenchmarkType enum
      value (raised by the proto enum Value() lookup).
  """

  results = test_log_pb2.TestResults()
  results.name = name
  results.target = test_name
  results.start_time = start_time
  results.run_time = run_time
  # Enum names are upper-case in the proto; accept any case from the caller.
  results.benchmark_type = test_log_pb2.TestResults.BenchmarkType.Value(
      benchmark_type.upper())

  # Gather source code information
  git_sha = get_git_commit_sha()
  if git_sha:
    results.commit_id.hash = git_sha

  # Merge every per-benchmark log file into the entries field.
  results.entries.CopyFrom(process_benchmarks(log_files))
  # NOTE: extend() iterates test_args; a plain string here would be split
  # into single characters, so callers must pass a list of argument strings.
  results.run_configuration.argument.extend(test_args)
  results.machine_configuration.CopyFrom(
      system_info_lib.gather_machine_configuration())
  return results
87
88
def process_benchmarks(log_files):
  """Parse serialized benchmark entries from log files into one proto.

  Args:
    log_files: Iterable of paths to log files, each containing a serialized
      BenchmarkEntries protobuf.

  Returns:
    A test_log_pb2.BenchmarkEntries proto with the contents of all files
    merged in.

  Raises:
    IOError: If a file's contents cannot be fully parsed as a serialized
      BenchmarkEntries proto.
  """
  benchmarks = test_log_pb2.BenchmarkEntries()
  for f in log_files:
    # Use a context manager so the file handle is closed deterministically
    # (the original left it to the garbage collector).
    with gfile.GFile(f, "rb") as log_file:
      content = log_file.read()
    # MergeFromString returns the number of bytes consumed; consuming fewer
    # than len(content) means the file was not a valid serialized proto.
    if benchmarks.MergeFromString(content) != len(content):
      # IOError matches the error contract documented by run_and_gather_logs
      # and is narrower than the bare Exception raised before.
      raise IOError("Failed parsing benchmark entry from %s" % f)
  return benchmarks
96
97
def run_and_gather_logs(name, test_name, test_args,
                        benchmark_type):
  """Run the bazel test given by test_name.  Gather and return the logs.

  Args:
    name: Benchmark target identifier.
    test_name: A unique bazel target, e.g. "//path/to:test"
    test_args: A string containing all arguments to run the target with;
      it is split with shlex before the target is invoked.
    benchmark_type: A string representing the BenchmarkType enum; the
      benchmark type for this target.

  Returns:
    A tuple (test_results, mangled_test_name), where
    test_results: A test_log_pb2.TestResults proto
    test_adjusted_name: Unique benchmark name that consists of
      benchmark name optionally followed by GPU type.

  Raises:
    ValueError: If the test_name is not a valid target.
    subprocess.CalledProcessError: If the target itself fails.
    IOError: If there are problems gathering test log output from the test.
    MissingLogsError: If we couldn't find benchmark logs.
  """
  # Require a fully-qualified single bazel target ("//pkg:target"): reject
  # empty names, relative paths (".."), package wildcards (":", ":all",
  # "...") and names with more or fewer than one ":" separator.
  if not (test_name and test_name.startswith("//") and ".." not in test_name and
          not test_name.endswith(":") and not test_name.endswith(":all") and
          not test_name.endswith("...") and len(test_name.split(":")) == 2):
    raise ValueError("Expected test_name parameter with a unique test, e.g.: "
                     "--test_name=//path/to:test")
  # "//path/to:test" -> "path/to/test" (strip("/") removes the leading "//").
  test_executable = test_name.rstrip().strip("/").replace(":", "/")

  if gfile.Exists(os.path.join("bazel-bin", test_executable)):
    # Running in standalone mode from core of the repository
    test_executable = os.path.join("bazel-bin", test_executable)
  else:
    # Hopefully running in sandboxed mode
    test_executable = os.path.join(".", test_executable)

  # If a recognized Tesla GPU is present, suffix the benchmark name with the
  # GPU model so results from different hardware are kept distinct.
  test_adjusted_name = name
  gpu_config = gpu_info_lib.gather_gpu_devices()
  if gpu_config:
    gpu_name = gpu_config[0].model
    gpu_short_name_match = re.search(r"Tesla (K40|K80|P100|V100)", gpu_name)
    if gpu_short_name_match:
      gpu_short_name = gpu_short_name_match.group(0)
      test_adjusted_name = name + "|" + gpu_short_name.replace(" ", "_")

  # Build a filesystem-safe log-file prefix inside a fresh temp directory;
  # the trailing "." separates the prefix from per-benchmark suffixes.
  temp_directory = tempfile.mkdtemp(prefix="run_and_gather_logs")
  mangled_test_name = (test_adjusted_name.strip("/")
                       .replace("|", "_").replace("/", "_").replace(":", "_"))
  test_file_prefix = os.path.join(temp_directory, mangled_test_name)
  test_file_prefix = "%s." % test_file_prefix

  try:
    if not gfile.Exists(test_executable):
      raise ValueError("Executable does not exist: %s" % test_executable)
    test_args = shlex.split(test_args)

    # This key is defined in tf/core/util/reporter.h as
    # TestReporter::kTestReporterEnv.
    # NOTE(review): this mutates the process-wide environment and is never
    # restored afterwards; subsequent runs in this process inherit it.
    os.environ["TEST_REPORT_FILE_PREFIX"] = test_file_prefix
    start_time = time.time()
    # Propagates subprocess.CalledProcessError if the target exits non-zero.
    subprocess.check_call([test_executable] + test_args)
    run_time = time.time() - start_time
    log_files = gfile.Glob("{}*".format(test_file_prefix))
    if not log_files:
      raise MissingLogsError("No log files found at %s." % test_file_prefix)

    return (process_test_logs(
        test_adjusted_name,
        test_name=test_name,
        test_args=test_args,
        benchmark_type=benchmark_type,
        start_time=int(start_time),
        run_time=run_time,
        log_files=log_files), test_adjusted_name)

  finally:
    # Best-effort cleanup of the temp directory; a failed delete must not
    # mask the real result or error from the test run.
    try:
      gfile.DeleteRecursively(temp_directory)
    except OSError:
      pass
179