# Common computation builders for XLA.

load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test")

licenses(["notice"])  # Apache 2.0

package(default_visibility = ["//tensorflow/compiler/xla/client:friends"])

# Filegroup used to collect source files for dependency checking.
filegroup(
    name = "c_srcs",
    data = glob([
        "**/*.cc",
        "**/*.h",
    ]),
)

# Generate test_suites for all backends, named "${backend}_tests".
generate_backend_suites()

cc_library(
    name = "arithmetic",
    srcs = ["arithmetic.cc"],
    hdrs = ["arithmetic.h"],
    deps = [
        ":constants",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/client:xla_computation",
        "@com_google_absl//absl/strings",
    ],
)

xla_test(
    name = "arithmetic_test",
    srcs = ["arithmetic_test.cc"],
    deps = [
        ":arithmetic",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

cc_library(
    name = "comparators",
    srcs = ["comparators.cc"],
    hdrs = ["comparators.h"],
    deps = [
        ":constants",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/client:xla_computation",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)

xla_test(
    name = "comparators_test",
    srcs = ["comparators_test.cc"],
    deps = [
        ":comparators",
        ":constants",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "@com_google_absl//absl/container:inlined_vector",
    ],
)

cc_library(
    name = "constants",
    srcs = ["constants.cc"],
    hdrs = ["constants.h"],
    deps = [
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
    ],
)

xla_test(
    name = "constants_test",
    srcs = ["constants_test.cc"],
    deps = [
        ":constants",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

cc_library(
    name = "conv_grad_size_util",
    srcs = ["conv_grad_size_util.cc"],
    hdrs = ["conv_grad_size_util.h"],
    deps = [
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla/client:padding",
        "//tensorflow/core:lib",
    ],
)

cc_library(
    name = "loops",
    srcs = ["loops.cc"],
    hdrs = ["loops.h"],
    deps = [
        ":constants",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/client:xla_computation",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "math",
    srcs = ["math.cc"],
    hdrs = ["math.h"],
    deps = [
        ":arithmetic",
        ":constants",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla/client:xla_builder",
    ],
)

xla_test(
    name = "math_test",
    srcs = ["math_test.cc"],
    deps = [
        ":constants",
        ":math",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

cc_library(
    name = "matrix",
    srcs = ["matrix.cc"],
    hdrs = ["matrix.h"],
    deps = [
        ":arithmetic",
        ":constants",
        ":slicing",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)

xla_test(
    name = "matrix_test",
    srcs = ["matrix_test.cc"],
    deps = [
        ":matrix",
        ":slicing",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "pooling",
    srcs = ["pooling.cc"],
    hdrs = ["pooling.h"],
    deps = [
        ":arithmetic",
        ":constants",
        ":conv_grad_size_util",
        "//tensorflow/compiler/xla/client:xla_builder",
        "@com_google_absl//absl/container:inlined_vector",
    ],
)

xla_test(
    name = "pooling_test",
    srcs = ["pooling_test.cc"],
    deps = [
        ":pooling",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "@com_google_absl//absl/container:inlined_vector",
    ],
)

cc_library(
    name = "prng",
    srcs = ["prng.cc"],
    hdrs = ["prng.h"],
    deps = [
        ":constants",
        ":math",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "@com_google_absl//absl/base",
    ],
)

cc_library(
    name = "qr",
    srcs = ["qr.cc"],
    hdrs = ["qr.h"],
    deps = [
        ":arithmetic",
        ":constants",
        ":loops",
        ":math",
        ":matrix",
        ":slicing",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/core:lib",
    ],
)

xla_test(
    name = "qr_test",
    srcs = ["qr_test.cc"],
    tags = ["optonly"],
    deps = [
        ":matrix",
        ":qr",
        "//tensorflow/compiler/xla:array2d",
        "//tensorflow/compiler/xla:array3d",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:literal_test_util",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/core:test",
    ],
)

cc_library(
    name = "slicing",
    srcs = ["slicing.cc"],
    hdrs = ["slicing.h"],
    deps = [
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla/client:xla_builder",
        "@com_google_absl//absl/types:span",
    ],
)

xla_test(
    name = "slicing_test",
    srcs = ["slicing_test.cc"],
    deps = [
        ":slicing",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

cc_library(
    name = "sorting",
    srcs = ["sorting.cc"],
    hdrs = ["sorting.h"],
    deps = [
        ":comparators",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
    ],
)

xla_test(
    name = "sorting_test",
    srcs = ["sorting_test.cc"],
    deps = [
        ":sorting",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

# Header-only library: quantize.h has no accompanying .cc in srcs.
cc_library(
    name = "quantize",
    hdrs = ["quantize.h"],
    deps = [
        ":constants",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/core:lib",
    ],
)

xla_test(
    name = "quantize_test",
    srcs = ["quantize_test.cc"],
    # TODO(b/122119490): re-enable TAP after fixing.
    tags = [
        "notap",
    ],
    deps = [
        ":quantize",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

cc_library(
    name = "testing",
    srcs = ["testing.cc"],
    hdrs = ["testing.h"],
    deps = [
        "//tensorflow/compiler/xla:execution_options_util",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client",
        "//tensorflow/compiler/xla/client:global_data",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/client:xla_computation",
        "//tensorflow/compiler/xla/tests:test_utils",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "self_adjoint_eig",
    srcs = ["self_adjoint_eig.cc"],
    hdrs = ["self_adjoint_eig.h"],
    deps = [
        ":arithmetic",
        ":comparators",
        ":constants",
        ":loops",
        ":math",
        ":matrix",
        ":slicing",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/core:lib",
    ],
)

# Not run on the cpu or gpu backends, and restricted to real hardware.
xla_test(
    name = "self_adjoint_eig_test",
    srcs = ["self_adjoint_eig_test.cc"],
    blacklisted_backends = [
        "cpu",
        "gpu",
    ],
    real_hardware_only = True,
    shard_count = 10,
    tags = ["optonly"],
    deps = [
        ":arithmetic",
        ":constants",
        ":matrix",
        ":self_adjoint_eig",
        "//tensorflow/compiler/xla:array2d",
        "//tensorflow/compiler/xla:array3d",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:literal_test_util",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/core:test",
    ],
)

cc_library(
    name = "svd",
    srcs = ["svd.cc"],
    hdrs = ["svd.h"],
    deps = [
        ":arithmetic",
        ":comparators",
        ":constants",
        ":loops",
        ":math",
        ":matrix",
        ":slicing",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/core:lib",
    ],
)

# Not run on the cpu or gpu backends, and restricted to real hardware.
xla_test(
    name = "svd_test",
    srcs = ["svd_test.cc"],
    blacklisted_backends = [
        "cpu",
        "gpu",
    ],
    real_hardware_only = True,
    shard_count = 10,
    tags = ["optonly"],
    deps = [
        ":arithmetic",
        ":constants",
        ":matrix",
        ":slicing",
        ":svd",
        "//tensorflow/compiler/xla:array2d",
        "//tensorflow/compiler/xla:array3d",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:literal_test_util",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/core:test",
    ],
)
