# Common computation builders for XLA.
#
# Each cc_library below pairs one `<name>.cc` / `<name>.h` in this package
# (`quantize` is header-only), and most have a matching `<name>_test` declared
# with the `xla_test` macro loaded from
# //tensorflow/compiler/xla/tests:build_defs.bzl.

load("//tensorflow:tensorflow.bzl", "filegroup")
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test")

# Package-wide defaults: targets are visible only to the XLA client "friends"
# package group unless a target overrides `visibility`.
package(
    default_visibility = ["//tensorflow/compiler/xla/client:friends"],
    licenses = ["notice"],
)

# Filegroup used to collect source files for dependency checking.
# NOTE(review): the glob is passed via `data` rather than `srcs`; `filegroup`
# here is the wrapper loaded from tensorflow.bzl -- confirm this is intended.
filegroup(
    name = "c_srcs",
    data = glob([
        "**/*.cc",
        "**/*.h",
    ]),
)

# Generate test_suites for all backends, named "${backend}_tests".
generate_backend_suites()

cc_library(
    name = "arithmetic",
    srcs = ["arithmetic.cc"],
    hdrs = ["arithmetic.h"],
    deps = [
        ":constants",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/client:xla_computation",
        "@com_google_absl//absl/strings",
    ],
)

xla_test(
    name = "arithmetic_test",
    srcs = ["arithmetic_test.cc"],
    deps = [
        ":arithmetic",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:test_macros_header",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

cc_library(
    name = "comparators",
    srcs = ["comparators.cc"],
    hdrs = [
        "comparators.h",
    ],
    deps = [
        ":constants",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/client:xla_computation",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)

xla_test(
    name = "comparators_test",
    srcs = ["comparators_test.cc"],
    deps = [
        ":comparators",
        ":constants",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:test_macros_header",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "@com_google_absl//absl/container:inlined_vector",
    ],
)

# `constants` is a dependency of most other libraries in this package.
cc_library(
    name = "constants",
    srcs = ["constants.cc"],
    hdrs = ["constants.h"],
    deps = [
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_builder",
    ],
)

xla_test(
    name = "constants_test",
    srcs = ["constants_test.cc"],
    deps = [
        ":constants",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:test_macros_header",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

cc_library(
    name = "conv_grad_size_util",
    srcs = ["conv_grad_size_util.cc"],
    hdrs = ["conv_grad_size_util.h"],
    deps = [
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla/client:padding",
        "//tensorflow/core:lib",
    ],
)

cc_library(
    name = "dynamic_shaped_ops",
    srcs = ["dynamic_shaped_ops.cc"],
    hdrs = ["dynamic_shaped_ops.h"],
    deps = [
        ":constants",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:value_inference",
        "//tensorflow/compiler/xla/client:xla_builder",
    ],
)

cc_library(
    name = "loops",
    srcs = ["loops.cc"],
    hdrs = ["loops.h"],
    deps = [
        ":constants",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/client:xla_computation",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "math",
    srcs = ["math.cc"],
    hdrs = ["math.h"],
    deps = [
        ":arithmetic",
        ":constants",
        ":loops",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla/client:xla_builder",
    ],
)

xla_test(
    name = "math_test",
    srcs = ["math_test.cc"],
    deps = [
        ":constants",
        ":math",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:test_macros_header",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

cc_library(
    name = "matrix",
    srcs = ["matrix.cc"],
    hdrs = ["matrix.h"],
    deps = [
        ":arithmetic",
        ":constants",
        ":slicing",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_builder",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
    ],
)

xla_test(
    name = "matrix_test",
    srcs = ["matrix_test.cc"],
    deps = [
        ":constants",
        ":matrix",
        ":slicing",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:test_macros_header",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "pooling",
    srcs = ["pooling.cc"],
    hdrs = ["pooling.h"],
    deps = [
        ":arithmetic",
        ":constants",
        ":conv_grad_size_util",
        "//tensorflow/compiler/xla/client:xla_builder",
        "@com_google_absl//absl/container:inlined_vector",
    ],
)

xla_test(
    name = "pooling_test",
    srcs = ["pooling_test.cc"],
    deps = [
        ":pooling",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:test_macros_header",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "@com_google_absl//absl/container:inlined_vector",
    ],
)

cc_library(
    name = "prng",
    srcs = ["prng.cc"],
    hdrs = ["prng.h"],
    deps = [
        ":constants",
        ":math",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_builder",
        "@com_google_absl//absl/base",
    ],
)

cc_library(
    name = "qr",
    srcs = ["qr.cc"],
    hdrs = ["qr.h"],
    deps = [
        ":arithmetic",
        ":constants",
        ":loops",
        ":math",
        ":matrix",
        ":slicing",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/core:lib",
    ],
)

# NOTE(review): "optonly" presumably restricts this test to optimized builds;
# the tag's handling is not defined in this file -- confirm.
xla_test(
    name = "qr_test",
    srcs = ["qr_test.cc"],
    tags = ["optonly"],
    deps = [
        ":constants",
        ":matrix",
        ":qr",
        "//tensorflow/compiler/xla:array2d",
        "//tensorflow/compiler/xla:array3d",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:literal_test_util",
        "//tensorflow/compiler/xla/tests:test_macros_header",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/core:test",
        "//tensorflow/core/platform:tensor_float_32_utils",
    ],
)

# No test target for this library is declared in this file.
cc_library(
    name = "lu_decomposition",
    srcs = ["lu_decomposition.cc"],
    hdrs = ["lu_decomposition.h"],
    deps = [
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_builder",
    ],
)

cc_library(
    name = "slicing",
    srcs = ["slicing.cc"],
    hdrs = ["slicing.h"],
    deps = [
        ":arithmetic",
        ":constants",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla/client:xla_builder",
        "@com_google_absl//absl/types:span",
    ],
)

xla_test(
    name = "slicing_test",
    srcs = ["slicing_test.cc"],
    deps = [
        ":slicing",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:test_macros_header",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

cc_library(
    name = "sorting",
    srcs = ["sorting.cc"],
    hdrs = ["sorting.h"],
    deps = [
        ":comparators",
        ":constants",
        ":loops",
        ":slicing",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_builder",
    ],
)

xla_test(
    name = "sorting_test",
    srcs = ["sorting_test.cc"],
    deps = [
        ":sorting",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:test_macros_header",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

# Header-only library: it declares `hdrs` but no `srcs`.
cc_library(
    name = "quantize",
    hdrs = ["quantize.h"],
    deps = [
        ":constants",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/core:lib",
    ],
)

xla_test(
    name = "quantize_test",
    srcs = ["quantize_test.cc"],
    # TODO(b/122119490): re-enable TAP after fixing.
    tags = [
        "manual",
        "notap",
    ],
    deps = [
        ":quantize",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:test_macros_header",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

# Unlike the other libraries here, `testing` also depends on the XLA client
# itself, on global_data, and on the shared test utilities.
cc_library(
    name = "testing",
    srcs = ["testing.cc"],
    hdrs = ["testing.h"],
    deps = [
        "//tensorflow/compiler/xla:execution_options_util",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client",
        "//tensorflow/compiler/xla/client:global_data",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/client:xla_computation",
        "//tensorflow/compiler/xla/tests:test_utils",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "self_adjoint_eig",
    srcs = ["self_adjoint_eig.cc"],
    hdrs = ["self_adjoint_eig.h"],
    deps = [
        ":arithmetic",
        ":comparators",
        ":constants",
        ":loops",
        ":math",
        ":matrix",
        ":slicing",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/core:lib",
    ],
)

# NOTE(review): `real_hardware_only` and `shard_count` are passed to the
# xla_test macro; presumably they skip non-hardware backends and split the test
# into shards -- confirm against build_defs.bzl.
xla_test(
    name = "self_adjoint_eig_test",
    srcs = ["self_adjoint_eig_test.cc"],
    real_hardware_only = True,
    shard_count = 5,
    tags = ["optonly"],
    deps = [
        ":arithmetic",
        ":constants",
        ":math",
        ":matrix",
        ":self_adjoint_eig",
        "//tensorflow/compiler/xla:array",
        "//tensorflow/compiler/xla:array2d",
        "//tensorflow/compiler/xla:array3d",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:literal_test_util",
        "//tensorflow/compiler/xla/tests:test_macros_header",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/core:test",
    ],
)

cc_library(
    name = "svd",
    srcs = ["svd.cc"],
    hdrs = ["svd.h"],
    deps = [
        ":arithmetic",
        ":comparators",
        ":constants",
        ":loops",
        ":math",
        ":matrix",
        ":slicing",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/core:lib",
    ],
)

xla_test(
    name = "svd_test",
    srcs = ["svd_test.cc"],
    real_hardware_only = True,
    shard_count = 10,
    tags = ["optonly"],
    deps = [
        ":arithmetic",
        ":constants",
        ":matrix",
        ":slicing",
        ":svd",
        "//tensorflow/compiler/xla:array2d",
        "//tensorflow/compiler/xla:array3d",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:literal_test_util",
        "//tensorflow/compiler/xla/tests:test_macros_header",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/core:test",
        "//tensorflow/core/platform:tensor_float_32_utils",
    ],
)

cc_library(
    name = "tridiagonal",
    srcs = ["tridiagonal.cc"],
    hdrs = ["tridiagonal.h"],
    deps = [
        ":constants",
        ":loops",
        ":slicing",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_builder",
        "@com_google_absl//absl/types:span",
    ],
)

xla_test(
    name = "tridiagonal_test",
    srcs = ["tridiagonal_test.cc"],
    real_hardware_only = True,
    shard_count = 10,
    tags = ["optonly"],
    deps = [
        ":constants",
        ":slicing",
        ":tridiagonal",
        "//tensorflow/compiler/xla:array2d",
        "//tensorflow/compiler/xla:error_spec",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:literal_test_util",
        "//tensorflow/compiler/xla/tests:test_macros_header",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/stream_executor/lib",
    ],
)

# `logdet` depends on the `qr` library declared above.
cc_library(
    name = "logdet",
    srcs = ["logdet.cc"],
    hdrs = ["logdet.h"],
    deps = [
        ":arithmetic",
        ":constants",
        ":loops",
        ":math",
        ":matrix",
        ":qr",
        ":slicing",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/core:lib",
    ],
)

xla_test(
    name = "logdet_test",
    srcs = ["logdet_test.cc"],
    tags = [
        "optonly",
    ],
    deps = [
        ":logdet",
        ":matrix",
        "//tensorflow/compiler/xla:array2d",
        "//tensorflow/compiler/xla:array3d",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:literal_test_util",
        "//tensorflow/compiler/xla/tests:test_macros_header",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/core:test",
    ],
)