• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# -*- Python -*-
2
3load("//tensorflow:tensorflow.bzl", "tf_py_test")
4
def tf_cc_logged_benchmark(
        name = None,
        target = None,
        benchmarks = "..",
        tags = [],
        test_log_output_prefix = "",
        benchmark_type = "cpp_microbenchmark"):
    """Creates a benchmark test target for a TensorFlow C++ test (tf_cc_*_test).

    The generated target is a `tf_py_test` whose entry point is
    run_and_gather_logs.py (from //tensorflow/tools/test:run_and_gather_logs).
    That script is handed the wrapped test label, a `--benchmarks` filter and
    the benchmark type as command-line arguments.

    Args:
      name: Name of the generated benchmark target. Required.
      target: Label of a single test to benchmark, written as an absolute
        label containing ":" (e.g. "//path/to:test"). Labels ending in
        ":all" or "." (such as "//path/...") are rejected.
      benchmarks: Forwarded to the wrapped test as `--benchmarks=<value>`.
      tags: Extra tags for the generated target. They are merged, with
        duplicates removed, with "benchmark-test", "local", "manual" and
        "regression-test".
      test_log_output_prefix: Accepted for interface compatibility. It is
        currently not used by this macro.
      benchmark_type: Passed through as `--benchmark_type=<value>` (e.g.
        "cpp_microbenchmark" or "python_benchmark").
    """
    if not name:
        fail("Must provide a name")
    if not target:
        fail("Must provide a target")

    # Accept only a single, explicit test label. Reject anything that looks
    # like a wildcard pattern ("//pkg:all", "//pkg/...") or a relative label.
    if (":" not in target or
        not target.startswith("//") or
        target.endswith(":all") or
        target.endswith(".")):
        fail(" ".join((
            "Target must be a single well-defined test, e.g.,",
            "//path/to:test. Received: %s" % target,
        )))

    # Bazel no longer supports `depset + depset`. Building one depset from the
    # concatenated list keeps the original de-duplicating merge. `list(...)`
    # accepts tuples, and `or []` tolerates tags=None, both of which the
    # original `depset(tags)` accepted.
    all_tags = depset(list(tags or []) + [
        "benchmark-test",
        "local",
        "manual",
        "regression-test",
    ]).to_list()

    tf_py_test(
        name = name,
        tags = all_tags,
        size = "large",
        srcs = ["//tensorflow/tools/test:run_and_gather_logs"],
        args = [
            # Fully-qualified label of the benchmark target being generated.
            "--name=//%s:%s" % (native.package_name(), name),
            "--test_name=" + target,
            "--test_args=--benchmarks=%s" % benchmarks,
            "--benchmark_type=%s" % benchmark_type,
        ],
        # Declare the wrapped test as data so it is built alongside the runner.
        data = [
            target,
        ],
        main = "run_and_gather_logs.py",
        additional_deps = [
            "//tensorflow/tools/test:run_and_gather_logs",
        ],
    )

def tf_py_logged_benchmark(
        name = None,
        target = None,
        benchmarks = "..",
        tags = [],
        test_log_output_prefix = ""):
    """Creates a benchmark test target for a TensorFlow Python test (*py_tests).

    This currently delegates entirely to `tf_cc_logged_benchmark`, changing
    only the reported benchmark type to "python_benchmark". The two macros are
    kept separate so Python benchmarks can diverge from C++ ones later without
    changing call sites.

    Args:
      name: Name of the generated benchmark target. Required.
      target: Label of the single Python test to benchmark, e.g.
        "//path/to:test".
      benchmarks: Forwarded to the wrapped test as `--benchmarks=<value>`.
      tags: Extra tags for the generated target.
      test_log_output_prefix: Forwarded unchanged to `tf_cc_logged_benchmark`.
    """
    tf_cc_logged_benchmark(
        benchmark_type = "python_benchmark",
        name = name,
        target = target,
        tags = tags,
        benchmarks = benchmarks,
        test_log_output_prefix = test_log_output_prefix,
    )
