# Description: Operations defined for Cloud TPUs

load("//tensorflow:tensorflow.bzl", "tf_py_test")
load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library")
load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test")

# Do not add any more paths here. You do not need to be in the visibility list
# to use TPU symbols. They are accessible from tf.contrib.tpu in TF 1.x and
# tf.tpu and tf.compat.v1.tpu in TF 2.x.
package(
    default_visibility = [
        "//learning/brain:__subpackages__",
        "//learning/deepmind:__subpackages__",
        "//research/graph:__subpackages__",
        "//tensorflow:__subpackages__",
    ],
    licenses = ["notice"],  # Apache 2.0
)

exports_files(["tpu_test_wrapper.py"])

# Python-3-only test; tpu_test_wrapper.py is compiled in directly as a source
# rather than through a library target.
py_test(
    name = "tpu_test_wrapper_test",
    srcs = [
        "tpu_test_wrapper.py",
        "tpu_test_wrapper_test.py",
    ],
    main = "tpu_test_wrapper_test.py",
    python_version = "PY3",
    srcs_version = "PY3",
    tags = [
        "no_oss_py2",
        "no_oss_py35",
        "no_pip",
    ],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:platform",
        "@absl_py//absl/testing:flagsaver",
    ],
)

# Python module ops/tpu_ops.py, built on the tpu_ops_gen op bindings.
py_library(
    name = "tpu_py",
    srcs = ["ops/tpu_ops.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:tpu_ops_gen",
    ],
)

py_library(
    name = "async_checkpoint",
    srcs = ["async_checkpoint.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:state_ops",
        "//tensorflow/python:summary",
        "//tensorflow/python:summary_ops_v2",
        "//tensorflow/python:training",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/estimator:estimator_py",
    ],
)

tpu_py_test(
    name = "async_checkpoint_test",
    size = "medium",
    srcs = ["async_checkpoint_test.py"],
    disable_experimental = True,
    deps = [
        ":async_checkpoint",
        ":tpu_estimator",
        ":tpu_lib",
        "//tensorflow/python:lib",
        "//tensorflow/python:platform",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//third_party/py/numpy",
    ],
)

py_library(
    name = "preempted_hook_py",
    srcs = ["preempted_hook.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:errors",
        "//tensorflow/python:platform",
        "//tensorflow/python:session_run_hook",
        "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
    ],
)

py_library(
    name = "tpu_estimator",
    srcs = [
        "_tpu_estimator_embedding.py",
        "error_handling.py",
        "tpu_config.py",
        "tpu_context.py",
        "tpu_estimator.py",
        "util.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":async_checkpoint",
        ":feature_column",
        ":feature_column_v2",
        ":functional",
        ":preempted_hook_py",
        ":tpu_embedding",
        ":tpu_lib",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:function",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:session",
        "//tensorflow/python:state_ops",
        "//tensorflow/python:summary",
        "//tensorflow/python:summary_ops_v2",
        "//tensorflow/python:training",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/estimator:estimator_py",
        "//tensorflow/python/estimator:util",
        "@six_archive//:six",
    ],
)

# Publicly visible, unlike the package default visibility.
py_library(
    name = "functional",
    srcs = ["functional.py"],
    srcs_version = "PY2AND3",
    visibility = [
        "//visibility:public",
    ],
    deps = [
        "//tensorflow/python:tpu_ops_gen",
    ],
)

# Package entry point (__init__.py) including the estimator library.
py_library(
    name = "tpu",
    srcs = [
        "__init__.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":feature_column",
        ":feature_column_v2",
        ":tpu_embedding",
        ":tpu_estimator",
        ":tpu_lib",
    ],
)

# Same __init__.py entry point as :tpu plus api.py, but without depending on
# :tpu_estimator.
py_library(
    name = "tpu_noestimator",
    srcs = [
        "__init__.py",
        "api.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":feature_column",
        ":feature_column_v2",
        ":preempted_hook_py",
        ":tpu_embedding",
        ":tpu_lib",
    ],
)

py_library(
    name = "tpu_lib",
    srcs = [
        "__init__.py",
        "bfloat16.py",
        "device_assignment.py",
        "session_support.py",
        "tensor_tracer.py",
        "tensor_tracer_flags.py",
        "tensor_tracer_report.py",
        "topology.py",
        "tpu.py",
        "tpu_feed.py",
        "tpu_function.py",
        "tpu_optimizer.py",
        "tpu_sharding.py",
        "tpu_strategy_util.py",
        "tpu_system_metadata.py",
        "training_loop.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":datasets",
        ":functional",
        ":tpu_py",
        "//tensorflow/compiler/xla/experimental/xla_sharding",
        "//tensorflow/compiler/xla/python_api:xla_shape",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/core/protobuf/tpu:compilation_result_proto_py",
        "//tensorflow/core/protobuf/tpu:dynamic_padding_proto_py",
        "//tensorflow/core/protobuf/tpu:optimization_parameters_proto_py",
        "//tensorflow/core/protobuf/tpu:topology_proto_py",
        "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py",
        "//tensorflow/core/protobuf/tpu:tpu_embedding_output_layout_proto_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:batch_ops",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:control_flow_util",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:platform_analytics",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tpu_ops_gen",
        "//tensorflow/python:training",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/compiler/xla",
        "//tensorflow/python/ops/losses",
        "//tensorflow/python/tpu:tensor_tracer_proto_py",
        "//tensorflow/python/tpu/profiler",
        "@six_archive//:six",
    ],
)

py_library(
    name = "datasets",
    srcs = [
        "datasets.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:dtypes",
        "//tensorflow/python:function",
        "//tensorflow/python:functional_ops",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:iterator_ops",
        "//tensorflow/python/data/ops:readers",
    ],
)

tf_py_test(
    name = "datasets_test",
    size = "medium",
    srcs = ["datasets_test.py"],
    grpc_enabled = True,
    shard_count = 4,
    tags = ["no_oss"],
    deps = [
        ":datasets",
        "//tensorflow/python:client_testlib",
    ],
)

tf_py_test(
    name = "tpu_test",
    size = "small",
    srcs = ["tpu_test.py"],
    tags = [
        "no_oss",  # TODO(b/131157871): Reenable in OSS when fixed
        "no_windows",  # TODO: needs investigation on Windows
    ],
    deps = [
        ":tpu",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework",
        "//tensorflow/python:layers",
    ],
)

tf_py_test(
    name = "tpu_sharding_test",
    size = "small",
    srcs = ["tpu_sharding_test.py"],
    deps = [
        ":tpu",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework",
    ],
)

tf_py_test(
    name = "bfloat16_test",
    size = "small",
    srcs = ["bfloat16_test.py"],
    deps = [
        ":tpu",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework",
    ],
)

tf_py_test(
    name = "tpu_infeed_test",
    size = "small",
    srcs = ["tpu_infeed_test.py"],
    deps = [
        ":tpu",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_test_lib",
    ],
)

tf_py_test(
    name = "topology_test",
    size = "medium",
    srcs = ["topology_test.py"],
    deps = [
        ":tpu",
        "//tensorflow/python:framework_test_lib",
    ],
)

py_library(
    name = "tpu_embedding",
    srcs = [
        "tpu_embedding.py",
        "tpu_embedding_gradient.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":tpu_lib",
        "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:partitioned_variables",
        "//tensorflow/python:tpu_ops_gen",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "@six_archive//:six",
    ],
)

# NOTE(review): tpu_strategy_util.py is also listed in the srcs of :tpu_lib,
# which this target depends on; confirm whether both listings are intended.
py_library(
    name = "tpu_strategy_util",
    srcs = ["tpu_strategy_util.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":tpu_lib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:util",
        "//tensorflow/python/distribute:device_util",
        "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:tape",
    ],
)

py_library(
    name = "feature_column",
    srcs = ["feature_column.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":tpu_lib",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/feature_column",
        "//tensorflow/python/feature_column:feature_column_py",
    ],
)

py_library(
    name = "feature_column_v2",
    srcs = ["feature_column_v2.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":feature_column",
        ":tpu_lib",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/feature_column",
        "//tensorflow/python/feature_column:feature_column_py",
    ],
)

tf_py_test(
    name = "feature_column_test",
    srcs = [
        "feature_column_test.py",
    ],
    main = "feature_column_test.py",
    deps = [
        ":feature_column",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:lookup_ops",
        "//tensorflow/python:parsing_ops",
        "//tensorflow/python:session",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:variables",
        "//tensorflow/python/feature_column",
        "//tensorflow/python/feature_column:feature_column_py",
        "//third_party/py/numpy",
    ],
)

tf_py_test(
    name = "feature_column_v2_test",
    srcs = [
        "feature_column_v2_test.py",
    ],
    main = "feature_column_v2_test.py",
    deps = [
        ":feature_column_v2",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:lookup_ops",
        "//tensorflow/python:parsing_ops",
        "//tensorflow/python:session",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:variables",
        "//tensorflow/python/feature_column",
        "//tensorflow/python/feature_column:feature_column_py",
        "//third_party/py/numpy",
    ],
)

tf_proto_library(
    name = "tensor_tracer_proto",
    srcs = ["tensor_tracer.proto"],
    cc_api_version = 2,
    protodeps = [
        "//tensorflow/core:protos_all",
    ],
    visibility = ["//visibility:public"],
)