# Description:
#   Contains the Keras Utilities (internal TensorFlow version).

# Both test macros come from the same .bzl file, so load them in a single
# statement (buildifier: same-origin-load).
load("//tensorflow:tensorflow.bzl", "cuda_py_test", "tf_py_test")

package(
    # TODO(scottzhu): Remove non-keras deps from TF.
    default_visibility = [
        "//tensorflow/python/feature_column:__pkg__",
        "//tensorflow/python/keras:__subpackages__",
        "//tensorflow/tools/pip_package:__pkg__",
    ],
    licenses = ["notice"],  # Apache 2.0
)

filegroup(
    name = "all_py_srcs",
    srcs = glob(["*.py"]),
    visibility = ["//tensorflow/python/keras/google/private_tf_api_test:__pkg__"],
)

# Package entry point: provides __init__.py and pulls in every utility
# library through :all_utils.
py_library(
    name = "utils",
    srcs = [
        "__init__.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":all_utils",
    ],
)

# Aggregates the individual utility libraries listed in deps.
py_library(
    name = "all_utils",
    srcs = [
        "all_utils.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":control_flow_util",
        ":engine_utils",
        ":generic_utils",
        ":layer_utils",
        ":multi_gpu_utils",
        ":np_utils",
        ":vis_utils",
    ],
)

py_library(
    name = "control_flow_util",
    srcs = ["control_flow_util.py"],
    srcs_version = "PY3",
    deps = [],
)

py_library(
    name = "kpl_test_utils",
    srcs = ["kpl_test_utils.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python/keras",
        "//tensorflow/python/keras/layers/preprocessing:string_lookup",
    ],
)

py_library(
    name = "data_utils",
    srcs = ["data_utils.py"],
    srcs_version = "PY3",
    deps = [
        ":generic_utils",
        ":io_utils",
        ":tf_inspect",
    ],
)

# Groups conv_utils.py and losses_utils.py into a single library.
py_library(
    name = "engine_utils",
    srcs = [
        "conv_utils.py",
        "losses_utils.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":data_utils",
        ":io_utils",
        "//tensorflow/python/keras:backend",
    ],
)

py_library(
    name = "io_utils",
    srcs = ["io_utils.py"],
    srcs_version = "PY3",
    deps = [
        "@six_archive//:six",
    ],
)

py_library(
    name = "tf_utils",
    srcs = ["tf_utils.py"],
    srcs_version = "PY3",
    deps = [
        ":object_identity",
        "//tensorflow/python:composite_tensor",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:smart_cond",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python/eager:context",
        "@six_archive//:six",
    ],
)

py_library(
    name = "generic_utils",
    srcs = [
        "generic_utils.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":tf_contextlib",
        ":tf_inspect",
        "//tensorflow/python:util",
        "//third_party/py/numpy",
    ],
)

py_library(
    name = "mode_keys",
    srcs = [
        "mode_keys.py",
    ],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python/saved_model/model_utils:mode_keys",
    ],
)

# Also carries kernelized_utils.py (exercised by kernelized_utils_test).
py_library(
    name = "layer_utils",
    srcs = [
        "kernelized_utils.py",
        "layer_utils.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":engine_utils",
        "//tensorflow/python:util",
        "//tensorflow/python/keras:backend",
        "//third_party/py/numpy",
    ],
)

py_library(
    name = "metrics_utils",
    srcs = [
        "metrics_utils.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":generic_utils",
        ":tf_utils",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:check_ops",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:distribute",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:nn_ops",
        "//tensorflow/python:util",
        "//tensorflow/python:weights_broadcast_ops",
        "//tensorflow/python/ops/losses",
        "//tensorflow/python/ops/ragged:ragged_tensor",
        "//tensorflow/python/ops/ragged:ragged_util",
        "//tensorflow/python/tpu:tpu_lib",
    ],
)

py_library(
    name = "version_utils",
    srcs = [
        "version_utils.py",
    ],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:util",
    ],
)

py_library(
    name = "multi_gpu_utils",
    srcs = [
        "multi_gpu_utils.py",
    ],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:util",
        "//tensorflow/python/keras:backend",
        "//tensorflow/python/keras/layers",
    ],
)

py_library(
    name = "np_utils",
    srcs = [
        "np_utils.py",
    ],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python:util",
        "//third_party/py/numpy",
    ],
)

py_library(
    name = "object_identity",
    srcs = ["object_identity.py"],
    srcs_version = "PY3",
    deps = [],
)

py_library(
    name = "tf_contextlib",
    srcs = ["tf_contextlib.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python:util",
    ],
)

py_library(
    name = "tf_inspect",
    srcs = ["tf_inspect.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python:util",
    ],
)

py_library(
    name = "vis_utils",
    srcs = [
        "vis_utils.py",
    ],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python:util",
    ],
)

py_library(
    name = "dataset_creator",
    srcs = [
        "dataset_creator.py",
    ],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python:util",
    ],
)

tf_py_test(
    name = "dataset_creator_test",
    srcs = ["dataset_creator_test.py"],
    python_version = "PY3",
    tags = [
        "no_tfrt",  # TODO(b/180537361): Reenable TFRT after the issue is resolved.
    ],
    deps = [
        ":dataset_creator",
        "//tensorflow/python/distribute:multi_worker_test_base",
        "//tensorflow/python/keras/engine",
        "//tensorflow/python/keras/layers:core",
    ],
)

tf_py_test(
    name = "data_utils_test",
    size = "medium",
    srcs = ["data_utils_test.py"],
    python_version = "PY3",
    shard_count = 6,
    tags = [
        "noasan",  # times out
        "notsan",
        "optonly",  # times out
    ],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "generic_utils_test",
    size = "small",
    srcs = ["generic_utils_test.py"],
    python_version = "PY3",
    deps = [
        ":generic_utils",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "version_utils_test",
    size = "small",
    srcs = ["version_utils_test.py"],
    python_version = "PY3",
    deps = [
        ":version_utils",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "tf_utils_test",
    size = "small",
    srcs = ["tf_utils_test.py"],
    python_version = "PY3",
    deps = [
        ":tf_utils",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//tensorflow/python/keras:combinations",
    ],
)

tf_py_test(
    name = "composite_tensor_support_test",
    size = "medium",
    srcs = ["composite_tensor_support_test.py"],
    python_version = "PY3",
    shard_count = 8,
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:sparse_ops",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python/keras",
        # NOTE(review): dataset_creator_test depends on
        # //tensorflow/python/keras/engine instead; confirm this label is the
        # intended one.
        "//tensorflow/python/keras:engine",
        "//tensorflow/python/keras/layers",
        "//tensorflow/python/ops/ragged:ragged_tensor",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "io_utils_test",
    size = "small",
    srcs = ["io_utils_test.py"],
    python_version = "PY3",
    tags = [
        "no_windows",  # TODO: needs investigation on Windows
        "notsan",
    ],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "layer_utils_test",
    size = "small",
    srcs = ["layer_utils_test.py"],
    python_version = "PY3",
    deps = [
        ":layer_utils",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/training/tracking",
        "//third_party/py/numpy",
    ],
)

tf_py_test(
    name = "np_utils_test",
    size = "small",
    srcs = ["np_utils_test.py"],
    python_version = "PY3",
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "kernelized_utils_test",
    size = "small",
    srcs = ["kernelized_utils_test.py"],
    python_version = "PY3",
    deps = [
        ":layer_utils",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:layers",
        "@absl_py//absl/testing:parameterized",
    ],
)

cuda_py_test(
    name = "multi_gpu_utils_test",
    srcs = ["multi_gpu_utils_test.py"],
    python_version = "PY3",
    tags = [
        "guitar",
        "multi_gpu",
    ],
    xla_enable_strict_auto_jit = True,
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "vis_utils_test",
    size = "small",
    srcs = ["vis_utils_test.py"],
    python_version = "PY3",
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "conv_utils_test",
    size = "small",
    srcs = ["conv_utils_test.py"],
    python_version = "PY3",
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "metrics_utils_test",
    size = "small",
    srcs = ["metrics_utils_test.py"],
    python_version = "PY3",
    deps = [
        "//tensorflow/python:constant_op",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:ops",
        "//tensorflow/python:platform_test",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/keras",
        "//tensorflow/python/keras:combinations",
        "//tensorflow/python/ops/ragged:ragged_factory_ops",
        "//tensorflow/python/ops/ragged:ragged_tensor",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "losses_utils_test",
    size = "small",
    srcs = ["losses_utils_test.py"],
    python_version = "PY3",
    deps = [
        "//tensorflow/python:constant_op",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:ops",
        "//tensorflow/python:platform_test",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/keras",
        "//tensorflow/python/keras:combinations",
        "//tensorflow/python/ops/ragged:ragged_array_ops",
        "//tensorflow/python/ops/ragged:ragged_concat_ops",
        "//tensorflow/python/ops/ragged:ragged_factory_ops",
    ],
)