1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Test runner for TensorFlow tests.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import argparse 22import os 23import shlex 24from string import maketrans 25import sys 26import time 27 28from google.protobuf import json_format 29from google.protobuf import text_format 30 31from tensorflow.core.util import test_log_pb2 32from tensorflow.python.platform import app 33from tensorflow.python.platform import gfile 34from tensorflow.python.platform import test 35from tensorflow.python.platform import tf_logging 36from tensorflow.tools.test import run_and_gather_logs_lib 37 38# pylint: disable=g-import-not-at-top 39# pylint: disable=g-bad-import-order 40# pylint: disable=unused-import 41# Note: cpuinfo and psutil are not installed for you in the TensorFlow 42# OSS tree. They are installable via pip. 43try: 44 import cpuinfo 45 import psutil 46except ImportError as e: 47 tf_logging.error("\n\n\nERROR: Unable to import necessary library: {}. " 48 "Issuing a soft exit.\n\n\n".format(e)) 49 sys.exit(0) 50# pylint: enable=g-bad-import-order 51# pylint: enable=unused-import 52 53FLAGS = None 54 55 56def gather_build_configuration(): 57 build_config = test_log_pb2.BuildConfiguration() 58 build_config.mode = FLAGS.compilation_mode 59 # Include all flags except includes 60 cc_flags = [ 61 flag for flag in shlex.split(FLAGS.cc_flags) if not flag.startswith("-i") 62 ] 63 build_config.cc_flags.extend(cc_flags) 64 return build_config 65 66 67def main(unused_args): 68 name = FLAGS.name 69 test_name = FLAGS.test_name 70 test_args = FLAGS.test_args 71 benchmark_type = FLAGS.benchmark_type 72 test_results, _ = run_and_gather_logs_lib.run_and_gather_logs( 73 name, test_name=test_name, test_args=test_args, 74 benchmark_type=benchmark_type) 75 76 # Additional bits we receive from bazel 77 test_results.build_configuration.CopyFrom(gather_build_configuration()) 78 79 if not FLAGS.test_log_output_dir: 80 print(text_format.MessageToString(test_results)) 81 return 82 83 if FLAGS.test_log_output_filename: 84 file_name = FLAGS.test_log_output_filename 85 else: 86 file_name = (name.strip("/").translate(maketrans("/:", "__")) + 87 time.strftime("%Y%m%d%H%M%S", time.gmtime())) 88 if FLAGS.test_log_output_use_tmpdir: 89 tmpdir = test.get_temp_dir() 90 output_path = os.path.join(tmpdir, FLAGS.test_log_output_dir, file_name) 91 else: 92 output_path = os.path.join( 93 os.path.abspath(FLAGS.test_log_output_dir), file_name) 94 json_test_results = json_format.MessageToJson(test_results) 95 gfile.GFile(output_path + ".json", "w").write(json_test_results) 96 tf_logging.info("Test results written to: %s" % output_path) 97 98 99if __name__ == "__main__": 100 parser = argparse.ArgumentParser() 101 parser.register( 102 "type", "bool", lambda v: v.lower() in ("true", "t", "y", "yes")) 103 parser.add_argument( 104 "--name", type=str, default="", help="Benchmark target identifier.") 105 parser.add_argument( 106 "--test_name", type=str, default="", help="Test target to run.") 107 parser.add_argument( 108 "--benchmark_type", 109 type=str, 110 default="", 111 help="BenchmarkType enum string (benchmark type).") 112 parser.add_argument( 113 "--test_args", 114 type=str, 115 default="", 116 help="Test arguments, space separated.") 117 parser.add_argument( 118 "--test_log_output_use_tmpdir", 119 type="bool", 120 nargs="?", 121 const=True, 122 default=False, 123 help="Store the log output into tmpdir?") 124 parser.add_argument( 125 "--compilation_mode", 126 type=str, 127 default="", 128 help="Mode used during this build (e.g. opt, dbg).") 129 parser.add_argument( 130 "--cc_flags", 131 type=str, 132 default="", 133 help="CC flags used during this build.") 134 parser.add_argument( 135 "--test_log_output_dir", 136 type=str, 137 default="", 138 help="Directory to write benchmark results to.") 139 parser.add_argument( 140 "--test_log_output_filename", 141 type=str, 142 default="", 143 help="Filename to output benchmark results to. If the filename is not " 144 "specified, it will be automatically created based on --name " 145 "and current time.") 146 FLAGS, unparsed = parser.parse_known_args() 147 app.run(main=main, argv=[sys.argv[0]] + unparsed) 148