# Common computation builders for XLA.

load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test")

licenses(["notice"])  # Apache 2.0

package(default_visibility = ["//tensorflow/compiler/xla/client:friends"])

# Collects every .cc and .h file in this package (recursively) so that the
# sources can be inspected for dependency checking.
filegroup(
    name = "c_srcs",
    data = glob([
        "**/*.cc",
        "**/*.h",
    ]),
)

# Emits one test_suite per backend, each named "${backend}_tests".
generate_backend_suites()

# Targets below are declared in alphabetical order; each library is
# immediately followed by its xla_test, when it has one.

cc_library(
    name = "arithmetic",
    srcs = ["arithmetic.cc"],
    hdrs = ["arithmetic.h"],
    deps = [
        ":constants",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/client:xla_computation",
        "@com_google_absl//absl/strings",
    ],
)

xla_test(
    name = "arithmetic_test",
    srcs = ["arithmetic_test.cc"],
    deps = [
        ":arithmetic",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

cc_library(
    name = "comparators",
    srcs = ["comparators.cc"],
    hdrs = ["comparators.h"],
    deps = [
        ":constants",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/client:xla_computation",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)

xla_test(
    name = "comparators_test",
    srcs = ["comparators_test.cc"],
    deps = [
        ":comparators",
        ":constants",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "@com_google_absl//absl/container:inlined_vector",
    ],
)

cc_library(
    name = "constants",
    srcs = ["constants.cc"],
    hdrs = ["constants.h"],
    deps = [
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
    ],
)

xla_test(
    name = "constants_test",
    srcs = ["constants_test.cc"],
    deps = [
        ":constants",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

cc_library(
    name = "conv_grad_size_util",
    srcs = ["conv_grad_size_util.cc"],
    hdrs = ["conv_grad_size_util.h"],
    deps = [
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla/client:padding",
        "//tensorflow/core:lib",
    ],
)

cc_library(
    name = "loops",
    srcs = ["loops.cc"],
    hdrs = ["loops.h"],
    deps = [
        ":constants",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/client:xla_computation",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "math",
    srcs = ["math.cc"],
    hdrs = ["math.h"],
    deps = [
        ":arithmetic",
        ":constants",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla/client:xla_builder",
    ],
)

xla_test(
    name = "math_test",
    srcs = ["math_test.cc"],
    deps = [
        ":constants",
        ":math",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

cc_library(
    name = "matrix",
    srcs = ["matrix.cc"],
    hdrs = ["matrix.h"],
    deps = [
        ":arithmetic",
        ":constants",
        ":slicing",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)

xla_test(
    name = "matrix_test",
    srcs = ["matrix_test.cc"],
    deps = [
        ":matrix",
        ":slicing",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "pooling",
    srcs = ["pooling.cc"],
    hdrs = ["pooling.h"],
    deps = [
        ":arithmetic",
        ":constants",
        ":conv_grad_size_util",
        "//tensorflow/compiler/xla/client:xla_builder",
        "@com_google_absl//absl/container:inlined_vector",
    ],
)

xla_test(
    name = "pooling_test",
    srcs = ["pooling_test.cc"],
    deps = [
        ":pooling",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "@com_google_absl//absl/container:inlined_vector",
    ],
)

cc_library(
    name = "prng",
    srcs = ["prng.cc"],
    hdrs = ["prng.h"],
    deps = [
        ":constants",
        ":math",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "@com_google_absl//absl/base",
    ],
)

cc_library(
    name = "qr",
    srcs = ["qr.cc"],
    hdrs = ["qr.h"],
    deps = [
        ":arithmetic",
        ":constants",
        ":loops",
        ":math",
        ":matrix",
        ":slicing",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/core:lib",
    ],
)

xla_test(
    name = "qr_test",
    srcs = ["qr_test.cc"],
    tags = ["optonly"],
    deps = [
        ":matrix",
        ":qr",
        "//tensorflow/compiler/xla:array2d",
        "//tensorflow/compiler/xla:array3d",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:literal_test_util",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/core:test",
    ],
)

# Header-only library: it declares hdrs but no srcs.
cc_library(
    name = "quantize",
    hdrs = ["quantize.h"],
    deps = [
        ":constants",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/core:lib",
    ],
)

xla_test(
    name = "quantize_test",
    srcs = ["quantize_test.cc"],
    # TODO(b/122119490): re-enable TAP after fixing.
    tags = ["notap"],
    deps = [
        ":quantize",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

cc_library(
    name = "self_adjoint_eig",
    srcs = ["self_adjoint_eig.cc"],
    hdrs = ["self_adjoint_eig.h"],
    deps = [
        ":arithmetic",
        ":comparators",
        ":constants",
        ":loops",
        ":math",
        ":matrix",
        ":slicing",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/core:lib",
    ],
)

# The cpu and gpu backends are excluded; the test is restricted to real
# hardware and split across 10 shards.
xla_test(
    name = "self_adjoint_eig_test",
    srcs = ["self_adjoint_eig_test.cc"],
    blacklisted_backends = ["cpu", "gpu"],
    real_hardware_only = True,
    shard_count = 10,
    tags = ["optonly"],
    deps = [
        ":arithmetic",
        ":constants",
        ":matrix",
        ":self_adjoint_eig",
        "//tensorflow/compiler/xla:array2d",
        "//tensorflow/compiler/xla:array3d",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:literal_test_util",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/core:test",
    ],
)

cc_library(
    name = "slicing",
    srcs = ["slicing.cc"],
    hdrs = ["slicing.h"],
    deps = [
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla/client:xla_builder",
        "@com_google_absl//absl/types:span",
    ],
)

xla_test(
    name = "slicing_test",
    srcs = ["slicing_test.cc"],
    deps = [
        ":slicing",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

cc_library(
    name = "sorting",
    srcs = ["sorting.cc"],
    hdrs = ["sorting.h"],
    deps = [
        ":comparators",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
    ],
)

xla_test(
    name = "sorting_test",
    srcs = ["sorting_test.cc"],
    deps = [
        ":sorting",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
    ],
)

cc_library(
    name = "svd",
    srcs = ["svd.cc"],
    hdrs = ["svd.h"],
    deps = [
        ":arithmetic",
        ":comparators",
        ":constants",
        ":loops",
        ":math",
        ":matrix",
        ":slicing",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/core:lib",
    ],
)

# Same backend restrictions and sharding as self_adjoint_eig_test.
xla_test(
    name = "svd_test",
    srcs = ["svd_test.cc"],
    blacklisted_backends = ["cpu", "gpu"],
    real_hardware_only = True,
    shard_count = 10,
    tags = ["optonly"],
    deps = [
        ":arithmetic",
        ":constants",
        ":matrix",
        ":slicing",
        ":svd",
        "//tensorflow/compiler/xla:array2d",
        "//tensorflow/compiler/xla:array3d",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/tests:client_library_test_base",
        "//tensorflow/compiler/xla/tests:literal_test_util",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/core:test",
    ],
)

cc_library(
    name = "testing",
    srcs = ["testing.cc"],
    hdrs = ["testing.h"],
    deps = [
        "//tensorflow/compiler/xla:execution_options_util",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client",
        "//tensorflow/compiler/xla/client:global_data",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/client:xla_computation",
        "//tensorflow/compiler/xla/tests:test_utils",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
    ],
)