# Lint as: python2, python3
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test runner for TensorFlow tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import shlex
import sys
import time

from absl import app
import six

from google.protobuf import json_format
from google.protobuf import text_format
from tensorflow.core.util import test_log_pb2
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
from tensorflow.tools.test import run_and_gather_logs_lib

# pylint: disable=g-import-not-at-top
# pylint: disable=g-bad-import-order
# pylint: disable=unused-import
# Note: cpuinfo and psutil are not installed for you in the TensorFlow
# OSS tree. They are installable via pip.
try:
  import cpuinfo
  import psutil
except ImportError as e:
  # Exit with status 0 (a "soft exit") instead of failing when the optional
  # dependencies are missing.
  tf_logging.error("\n\n\nERROR: Unable to import necessary library: {}. "
                   "Issuing a soft exit.\n\n\n".format(e))
  sys.exit(0)
# pylint: enable=g-bad-import-order
# pylint: enable=unused-import

# Parsed command-line flags; populated in the `__main__` block below.
FLAGS = None

# Include-related compiler options that, when given as a bare token, take
# their argument (a directory or file) as the *next* token.
_INCLUDE_FLAGS_WITH_ARG = frozenset([
    "-I",
    "-idirafter",
    "-imacros",
    "-imultilib",
    "-include",
    "-iprefix",
    "-iquote",
    "-isysroot",
    "-isystem",
    "-iwithprefix",
    "-iwithprefixbefore",
])


def _drop_include_flags(flags):
  """Returns `flags` with include-path / include-file options removed.

  Handles both the joined form (`-I/path`, `-isystem/path`) and the
  separated form (`-I /path`, `-isystem /path`); in the latter case the
  argument token is dropped together with the option.

  Args:
    flags: Sequence of already-tokenized compiler flags.

  Returns:
    A new list containing the remaining flags in their original order.
  """
  kept = []
  skip_next = False
  for flag in flags:
    if skip_next:
      # This token is the argument of the preceding include option.
      skip_next = False
      continue
    if flag in _INCLUDE_FLAGS_WITH_ARG:
      skip_next = True
      continue
    if flag.startswith(("-I", "-i")):
      continue
    kept.append(flag)
  return kept


def gather_build_configuration():
  """Builds a `BuildConfiguration` proto from the bazel-provided flags.

  Reads --compilation_mode and --cc_flags from the global FLAGS.

  Returns:
    A `test_log_pb2.BuildConfiguration` with `mode` set and `cc_flags`
    holding every CC flag except include-related ones.
  """
  build_config = test_log_pb2.BuildConfiguration()
  build_config.mode = FLAGS.compilation_mode
  # Include all flags except includes (they are machine-specific paths).
  build_config.cc_flags.extend(
      _drop_include_flags(shlex.split(FLAGS.cc_flags)))
  return build_config


def main(unused_args):
  """Runs the requested test/benchmark and reports the gathered results.

  When --test_log_output_dir is empty, the results are printed in protobuf
  text format. Otherwise they are serialized to JSON and written to
  `<output dir>/<file name>.json`. The output dir is under the test tmpdir
  when --test_log_output_use_tmpdir is set. The file name comes from
  --test_log_output_filename, or is derived from --name plus a UTC
  timestamp.

  Args:
    unused_args: Remaining argv passed through by `absl.app.run`; ignored.
  """
  name = FLAGS.name
  test_name = FLAGS.test_name
  test_args = FLAGS.test_args
  benchmark_type = FLAGS.benchmark_type
  test_results, _ = run_and_gather_logs_lib.run_and_gather_logs(
      name,
      test_name=test_name,
      test_args=test_args,
      benchmark_type=benchmark_type)

  # Additional bits we receive from bazel
  test_results.build_configuration.CopyFrom(gather_build_configuration())

  if not FLAGS.test_log_output_dir:
    print(text_format.MessageToString(test_results))
    return

  if FLAGS.test_log_output_filename:
    file_name = FLAGS.test_log_output_filename
  else:
    # Turn the target label into a flat file name, e.g. "//a/b:c" -> "a_b_c",
    # then append a UTC timestamp.
    file_name = (
        six.ensure_str(name).strip("/").translate(str.maketrans("/:", "__")) +
        time.strftime("%Y%m%d%H%M%S", time.gmtime()))
  if FLAGS.test_log_output_use_tmpdir:
    tmpdir = test.get_temp_dir()
    output_path = os.path.join(tmpdir, FLAGS.test_log_output_dir, file_name)
  else:
    output_path = os.path.join(
        os.path.abspath(FLAGS.test_log_output_dir), file_name)
  json_path = six.ensure_str(output_path) + ".json"

  # Make sure the destination directory exists; otherwise the write below
  # fails, e.g. for a not-yet-created subdirectory of the test tmpdir.
  output_dir = os.path.dirname(json_path)
  if output_dir:
    gfile.MakeDirs(output_dir)

  json_test_results = json_format.MessageToJson(test_results)
  # Context manager guarantees the data is flushed and the handle closed.
  with gfile.GFile(json_path, "w") as output_file:
    output_file.write(json_test_results)
  tf_logging.info("Test results written to: %s", json_path)


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.register(
      "type", "bool", lambda v: v.lower() in ("true", "t", "y", "yes"))
  parser.add_argument(
      "--name", type=str, default="", help="Benchmark target identifier.")
  parser.add_argument(
      "--test_name", type=str, default="", help="Test target to run.")
  parser.add_argument(
      "--benchmark_type",
      type=str,
      default="",
      help="BenchmarkType enum string (benchmark type).")
  parser.add_argument(
      "--test_args",
      type=str,
      default="",
      help="Test arguments, space separated.")
  parser.add_argument(
      "--test_log_output_use_tmpdir",
      type="bool",
      nargs="?",
      const=True,
      default=False,
      help="Store the log output into tmpdir?")
  parser.add_argument(
      "--compilation_mode",
      type=str,
      default="",
      help="Mode used during this build (e.g. opt, dbg).")
  parser.add_argument(
      "--cc_flags",
      type=str,
      default="",
      help="CC flags used during this build.")
  parser.add_argument(
      "--test_log_output_dir",
      type=str,
      default="",
      help="Directory to write benchmark results to.")
  parser.add_argument(
      "--test_log_output_filename",
      type=str,
      default="",
      help="Filename to output benchmark results to. If the filename is not "
      "specified, it will be automatically created based on --name "
      "and current time.")
  FLAGS, unparsed = parser.parse_known_args()
  # Hand any flags argparse did not recognize through to absl.
  app.run(main=main, argv=[sys.argv[0]] + unparsed)