• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Build rules for XLA testing."""
2
3load("//tensorflow:tensorflow.bzl", "tf_cc_test")
4load("//tensorflow/compiler/xla/tests:plugin.bzl", "plugins")
5load(
6    "//tensorflow/stream_executor:build_defs.bzl",
7    "if_gpu_is_configured",
8)
9load(
10    "//tensorflow/core/platform:build_config_root.bzl",
11    "tf_gpu_tests_tags",
12)
13
# Default backend list for the macros below: the built-in "cpu" and "gpu"
# backends followed by every key of the `plugins` dict loaded from plugin.bzl.
all_backends = ["cpu", "gpu"] + [plugin_name for plugin_name in plugins.keys()]
15
def xla_test(
        name,
        srcs,
        deps,
        xla_test_library_deps = [],
        backends = [],
        disabled_backends = [],
        real_hardware_only = False,
        args = [],
        tags = [],
        copts = [],
        data = [],
        backend_tags = {},
        backend_args = {},
        **kwargs):
    """Generates cc_test targets for the given XLA backends.

    This rule generates a cc_test target for one or more XLA backends and also a
    platform-agnostic cc_library rule. The arguments are identical to cc_test with
    a few additions: 'backends' specifies the backends to generate tests for
    ("cpu", "gpu", or any backend registered in plugins), and
    'backend_args'/'backend_tags' specify backend-specific args/tags to use when
    generating each cc_test.

    The name of each cc_test is the provided name argument with "_" and the
    backend name appended, and the cc_library target name is the provided name
    argument with "_lib" appended. For example, if the name parameter is
    "foo_test", then the cpu test target will be "foo_test_cpu" and the
    cc_library target is "foo_test_lib".

    The cc_library target can be used to link with other plugins outside of
    xla_test.

    The build rule also defines a test suite ${name} which includes the tests for
    each of the supported backends.

    Each generated cc_test target has a tag indicating which backend the test is
    for. This tag is of the form "xla_${BACKEND}" (eg, "xla_cpu"). These
    tags can be used to gather tests for a particular backend into a test_suite.

    Examples:

      # Generates the targets: foo_test_cpu and foo_test_gpu.
      xla_test(
          name = "foo_test",
          srcs = ["foo_test.cc"],
          backends = ["cpu", "gpu"],
          deps = [...],
      )

      # Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu
      # includes the additional arg "--special_cpu_flag".
      xla_test(
          name = "bar_test",
          srcs = ["bar_test.cc"],
          backends = ["cpu", "gpu"],
          backend_args = {"cpu": ["--special_cpu_flag"]},
          deps = [...],
      )

    The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND}
    to the value 1 where ${BACKEND} is the uppercase name of the backend.

    Args:
      name: Name of the target.
      srcs: Sources for the target.
      deps: Dependencies of the target.
      xla_test_library_deps: If set, the generated test targets will depend on the
        respective cc_libraries generated by the xla_test_library rule.
      backends: A list of backends to generate tests for. Supported values: "cpu",
        "gpu", and any backend registered in plugins. If this list is empty, the
        test will be generated for all supported backends.
      disabled_backends: A list of backends to NOT generate tests for. It is an
        error for this to remove every backend.
      real_hardware_only: Accepted for compatibility and otherwise ignored; every
        backend in all_backends is treated as real hardware.
      args: Test arguments for the target.
      tags: Tags for the target.
      copts: Additional copts to pass to the build.
      data: Additional data to pass to the build.
      backend_tags: A dict mapping backend name to list of additional tags to
        use for that target.
      backend_args: A dict mapping backend name to list of additional args to
        use for that target. The caller's lists are never modified.
      **kwargs: Additional keyword arguments to pass to tf_cc_test.
    """

    # All of the backends in all_backends are real hardware, so there is nothing
    # for real_hardware_only to filter; referencing it keeps linters quiet.
    _ignore = [real_hardware_only]

    if not backends:
        backends = all_backends

    backends = [
        backend
        for backend in backends
        if backend not in disabled_backends
    ]
    if not backends:
        # A test_suite with an empty `tests` list would pick up every non-manual
        # test in the package, so refuse to generate one.
        fail("xla_test %s: no backends left after applying disabled_backends %s" %
             (name, disabled_backends))

    native.cc_library(
        name = "%s_lib" % name,
        srcs = srcs,
        copts = copts,
        testonly = True,
        deps = deps,
    )

    test_names = []
    for backend in backends:
        test_name = "%s_%s" % (name, backend)
        this_backend_tags = ["xla_%s" % backend]
        this_backend_copts = []
        this_backend_args = backend_args.get(backend, [])
        this_backend_data = []
        if backend == "cpu":
            backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
            backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
        elif backend == "gpu":
            backend_deps = if_gpu_is_configured(["//tensorflow/compiler/xla/service:gpu_plugin"])
            backend_deps += if_gpu_is_configured(["//tensorflow/compiler/xla/tests:test_macros_gpu"])
            this_backend_tags += tf_gpu_tests_tags()
        elif backend in plugins:
            backend_deps = []
            backend_deps += plugins[backend]["deps"]
            this_backend_copts += plugins[backend]["copts"]
            this_backend_tags += plugins[backend]["tags"]

            # Rebind instead of `+=`: this_backend_args may be the caller's own
            # list from backend_args, and Starlark's list `+=` extends in place.
            this_backend_args = this_backend_args + plugins[backend]["args"]
            this_backend_data += plugins[backend]["data"]
        else:
            fail("Unknown backend %s" % backend)

        # Link the per-backend cc_library produced by xla_test_library.
        for lib_dep in xla_test_library_deps:
            backend_deps += ["%s_%s" % (lib_dep, backend)]

        tf_cc_test(
            name = test_name,
            srcs = srcs,
            tags = tags + backend_tags.get(backend, []) + this_backend_tags,
            extra_copts = copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
                          this_backend_copts,
            args = args + this_backend_args,
            deps = deps + backend_deps,
            data = data + this_backend_data,
            **kwargs
        )

        test_names.append(test_name)

    native.test_suite(name = name, tags = tags, tests = test_names)
162
def xla_test_library(
        name,
        srcs,
        hdrs = [],
        deps = [],
        backends = []):
    """Generates cc_library targets for the given XLA backends.

    This rule forces the sources to be compiled once per backend so that the
    backend-specific macros expand correctly. It is useful when test targets in
    different directories refer to the same sources but test with different
    arguments.

    Examples:

      # Generates the targets: foo_test_library_cpu and foo_test_library_gpu.
      xla_test_library(
          name = "foo_test_library",
          srcs = ["foo_test.cc"],
          backends = ["cpu", "gpu"],
          deps = [...],
      )
      # Then use the xla_test rule to generate test targets:
      xla_test(
          name = "foo_test",
          srcs = [],
          backends = ["cpu", "gpu"],
          deps = [...],
          xla_test_library_deps = [":foo_test_library"],
      )

    Args:
      name: Name of the target. Each generated library is named
        "${name}_${BACKEND}".
      srcs: Sources for the target.
      hdrs: Headers for the target.
      deps: Dependencies of the target.
      backends: A list of backends to generate libraries for.
        Supported values: "cpu", "gpu", and any backend registered in plugins.
        If this list is empty, the library will be generated for all supported
        backends.
    """

    if not backends:
        backends = all_backends

    for backend in backends:
        this_backend_copts = []
        if backend in ["cpu", "gpu"]:
            backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend]
        elif backend in plugins:
            backend_deps = plugins[backend]["deps"]
            this_backend_copts += plugins[backend]["copts"]
        else:
            fail("Unknown backend %s" % backend)

        native.cc_library(
            name = "%s_%s" % (name, backend),
            srcs = srcs,
            testonly = True,
            hdrs = hdrs,
            # Same XLA_TEST_BACKEND_<BACKEND> define that xla_test passes to
            # its cc_test targets.
            copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
                    this_backend_copts,
            deps = deps + backend_deps,
        )
226
def generate_backend_suites(backends = []):
    """Defines one "<backend>_tests" test_suite per backend.

    Each suite lists no explicit `tests`. It carries the tags
    "xla_<backend>", "-broken" and "manual", so its membership comes from
    Bazel's tag filtering over the package's tests. xla_test tags every
    generated cc_test with "xla_<backend>".

    Args:
      backends: Backends to create suites for. When empty, all_backends is used.
    """
    for backend in (backends or all_backends):
        native.test_suite(
            name = backend + "_tests",
            tags = ["xla_" + backend, "-broken", "manual"],
        )
235
def generate_backend_test_macros(backends = []):
    """Defines a testonly "test_macros_<backend>" cc_library per backend.

    Each library compiles test_macros.cc / test_macros.h with two defines:
    XLA_PLATFORM, the upper-cased backend name as a string literal, and
    XLA_DISABLED_MANIFEST, the plugin's "disabled_manifest" entry. The manifest
    is an empty string for backends that are not plugins.

    Args:
      backends: Backends to create libraries for. When empty, all_backends is
        used.
    """
    for backend in (backends or all_backends):
        disabled_manifest = plugins[backend]["disabled_manifest"] if backend in plugins else ""
        platform_define = "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper()
        manifest_define = "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % disabled_manifest

        native.cc_library(
            name = "test_macros_%s" % backend,
            testonly = True,
            srcs = ["test_macros.cc"],
            hdrs = ["test_macros.h"],
            copts = [platform_define, manifest_define],
            deps = ["//tensorflow/core/platform:logging"],
        )
257