• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Lint as: python2, python3
2# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Test runner for TensorFlow tests."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import argparse
23import os
24import shlex
25import sys
26import time
27
28from absl import app
29import six
30
31from google.protobuf import json_format
32from google.protobuf import text_format
33from tensorflow.core.util import test_log_pb2
34from tensorflow.python.platform import gfile
35from tensorflow.python.platform import test
36from tensorflow.python.platform import tf_logging
37from tensorflow.tools.test import run_and_gather_logs_lib
38
39# pylint: disable=g-import-not-at-top
40# pylint: disable=g-bad-import-order
41# pylint: disable=unused-import
42# Note: cpuinfo and psutil are not installed for you in the TensorFlow
43# OSS tree.  They are installable via pip.
44try:
45  import cpuinfo
46  import psutil
47except ImportError as e:
48  tf_logging.error("\n\n\nERROR: Unable to import necessary library: {}.  "
49                   "Issuing a soft exit.\n\n\n".format(e))
50  sys.exit(0)
51# pylint: enable=g-bad-import-order
52# pylint: enable=unused-import
53
# Parsed command-line flags (the argparse.Namespace returned by
# `parser.parse_known_args()`). Assigned in the `__main__` block before
# `app.run` invokes `main`; remains None if this module is only imported.
FLAGS = None
55
56
def gather_build_configuration():
  """Builds a BuildConfiguration proto from the bazel-provided flags.

  Reads `FLAGS.compilation_mode` into `mode`. Splits `FLAGS.cc_flags` with
  shell-like quoting rules (`shlex.split`) and copies every token into
  `cc_flags`, except include-related flags. Both the `-I<dir>` form and the
  `-i...` family (`-isystem`, `-iquote`, `-include`, ...) are skipped.

  Returns:
    A `test_log_pb2.BuildConfiguration` with `mode` and `cc_flags` populated.
  """
  build_config = test_log_pb2.BuildConfiguration()
  build_config.mode = FLAGS.compilation_mode
  # Include all flags except includes. `-I` must be listed explicitly: the
  # lowercase "-i" prefix alone lets plain include paths (`-I<dir>`) through.
  cc_flags = [
      flag for flag in shlex.split(FLAGS.cc_flags)
      if not flag.startswith(("-i", "-I"))
  ]
  build_config.cc_flags.extend(cc_flags)
  return build_config
66
67
def _output_file_name(name):
  """Returns the base name (no extension) of the results file.

  Uses `--test_log_output_filename` when it is set. Otherwise the name is
  derived from `name` in three steps: strip leading/trailing "/", replace
  "/" and ":" with "_", and append the current UTC time formatted as
  YYYYmmddHHMMSS.
  """
  if FLAGS.test_log_output_filename:
    return FLAGS.test_log_output_filename
  sanitized = six.ensure_str(name).strip("/").translate(
      str.maketrans("/:", "__"))
  return sanitized + time.strftime("%Y%m%d%H%M%S", time.gmtime())


def _output_path(file_name):
  """Returns the results path (no extension) for `file_name`.

  With `--test_log_output_use_tmpdir`, the path is rooted under the test temp
  dir joined with `--test_log_output_dir`. Otherwise it is rooted at the
  absolute form of `--test_log_output_dir`.
  """
  if FLAGS.test_log_output_use_tmpdir:
    return os.path.join(test.get_temp_dir(), FLAGS.test_log_output_dir,
                        file_name)
  return os.path.join(os.path.abspath(FLAGS.test_log_output_dir), file_name)


def main(unused_args):
  """Runs the requested test/benchmark and records its results.

  The target runs via `run_and_gather_logs_lib.run_and_gather_logs`, and the
  build configuration from the command-line flags is attached to the
  returned results proto. If `--test_log_output_dir` is empty, the results
  are printed to stdout in protobuf text format. Otherwise they are written
  as JSON to `<output path>.json`; any missing parent directories are
  created first.

  Args:
    unused_args: Remaining argv forwarded by `app.run`; ignored.
  """
  name = FLAGS.name
  test_results, _ = run_and_gather_logs_lib.run_and_gather_logs(
      name,
      test_name=FLAGS.test_name,
      test_args=FLAGS.test_args,
      benchmark_type=FLAGS.benchmark_type)

  # Additional bits we receive from bazel.
  test_results.build_configuration.CopyFrom(gather_build_configuration())

  if not FLAGS.test_log_output_dir:
    print(text_format.MessageToString(test_results))
    return

  output_path = (
      six.ensure_str(_output_path(_output_file_name(name))) + ".json")
  json_test_results = json_format.MessageToJson(test_results)
  # The destination directory may not exist yet. This is notably the case for
  # the subdirectory joined under a fresh test tmpdir. MakeDirs behaves like
  # `mkdir -p`.
  gfile.MakeDirs(os.path.dirname(output_path))
  # Close the file deterministically so the JSON is flushed before exit,
  # instead of relying on garbage collection of an unreferenced GFile.
  with gfile.GFile(output_path, "w") as output_file:
    output_file.write(json_test_results)
  # Log the path that was actually written (including the ".json" suffix).
  tf_logging.info("Test results written to: %s", output_path)
100
101
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  # Custom "bool" argument type. It yields True only for "true", "t", "y" or
  # "yes" (case-insensitive); any other value yields False.
  parser.register(
      "type", "bool", lambda v: v.lower() in ("true", "t", "y", "yes"))

  def _add_string_flag(flag, help_text):
    """Registers a string-valued flag whose default is the empty string."""
    parser.add_argument(flag, type=str, default="", help=help_text)

  # Registration order is preserved because it determines --help ordering.
  _add_string_flag("--name", "Benchmark target identifier.")
  _add_string_flag("--test_name", "Test target to run.")
  _add_string_flag("--benchmark_type",
                   "BenchmarkType enum string (benchmark type).")
  _add_string_flag("--test_args", "Test arguments, space separated.")
  # Passing the bare flag (no value) means True, via const.
  parser.add_argument(
      "--test_log_output_use_tmpdir",
      type="bool",
      nargs="?",
      const=True,
      default=False,
      help="Store the log output into tmpdir?")
  _add_string_flag("--compilation_mode",
                   "Mode used during this build (e.g. opt, dbg).")
  _add_string_flag("--cc_flags", "CC flags used during this build.")
  _add_string_flag("--test_log_output_dir",
                   "Directory to write benchmark results to.")
  _add_string_flag(
      "--test_log_output_filename",
      "Filename to output benchmark results to. If the filename is not "
      "specified, it will be automatically created based on --name "
      "and current time.")

  # Known flags populate the module-level FLAGS. Anything unrecognized is
  # handed on to absl's app.run together with the program name.
  FLAGS, unparsed = parser.parse_known_args()
  forwarded_argv = [sys.argv[0]]
  forwarded_argv.extend(unparsed)
  app.run(main=main, argv=forwarded_argv)
151