load("//tensorflow/tsl/platform/default:distribute.bzl", "distribute_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test")

package(licenses = ["notice"])

# Dependencies of mwms_peer_failure_test.py.
#
# Both `mwms_peer_failure_test` (the runnable test) and
# `mwms_peer_failure_test_lib` (the same source exposed as a library) compile
# this one file, so they share a single list instead of two hand-copied ones
# that could drift apart.
_MWMS_PEER_FAILURE_TEST_DEPS = [
    "//tensorflow:tensorflow_py",
    "//tensorflow/python/distribute:collective_all_reduce_strategy",
    "//tensorflow/python/distribute:combinations",
    "//tensorflow/python/distribute:multi_process_runner",
    "//tensorflow/python/distribute:multi_worker_test_base",
    "//tensorflow/python/distribute:test_util",
    "//tensorflow/python/eager:test",
]

distribute_py_test(
    name = "saved_model_test",
    srcs = ["saved_model_test.py"],
    tags = [
        "no_windows",  # TODO(b/171350360)
        "nomultivm",  # multi_worker_test_base incompatible with multivm base
        "notsan",  # b/195248428
    ],
    deps = [
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:lookup_ops",
        "//tensorflow/python/distribute:combinations",
        "//tensorflow/python/distribute:multi_worker_test_base",
        "//tensorflow/python/distribute:parameter_server_strategy_v2",
        "//tensorflow/python/distribute:sharded_variable",
        "//tensorflow/python/distribute:strategy_combinations",
        "//tensorflow/python/distribute:test_util",
        "//tensorflow/python/distribute:values",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:test",
        "@absl_py//absl/testing:parameterized",
    ],
)

cuda_py_test(
    name = "mwms_peer_failure_test",
    size = "medium",
    srcs = ["mwms_peer_failure_test.py"],
    python_version = "PY3",
    shard_count = 2,
    tags = [
        "multi_and_single_gpu",
        "no_oss",  # TODO(b/227372713)
        "notsan",  # b/195248428
    ],
    deps = _MWMS_PEER_FAILURE_TEST_DEPS,
)

# The same test source as `mwms_peer_failure_test`, packaged as a library so it
# can be depended on from //learning/brain/runtime/python (the only package
# granted visibility).
py_library(
    name = "mwms_peer_failure_test_lib",
    srcs = ["mwms_peer_failure_test.py"],
    visibility = ["//learning/brain/runtime/python:__pkg__"],
    deps = _MWMS_PEER_FAILURE_TEST_DEPS,
)

tpu_py_test(
    name = "tpu_memory_test",
    size = "medium",
    srcs = ["tpu_memory_test.py"],
    disable_experimental = True,
    disable_mlir_bridge = True,
    disable_tfrt = False,
    disable_v2 = True,
    python_version = "PY3",
    tags = ["no_oss"],
    deps = [
        "//tensorflow:tensorflow_py",
        "//tensorflow/python/distribute:tpu_strategy",
        "//tensorflow/python/eager:context",
        "//third_party/py/numpy",
    ],
)