# Description: Utilities for TPU Operations

load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load(
    "//tensorflow:tensorflow.bzl",
    "if_libtpu",
    "if_windows",
    "tf_cc_test",
)

package(
    default_visibility = [
        "//tensorflow/compiler/mlir/tensorflow:__subpackages__",
        "//tensorflow/compiler/tf2xla/kernels:__subpackages__",
        "//tensorflow/compiler/xla:__subpackages__",
        "//tensorflow/compiler/xrt:__subpackages__",
        "//tensorflow/core/profiler/backends/tpu:__subpackages__",
        "//tensorflow/core/tpu:__subpackages__",
        "//tensorflow/dtensor:__subpackages__",
        "//tensorflow/stream_executor/tpu:__subpackages__",
    ],
    licenses = ["notice"],
)

# Header-only target exposing libtftpu.h; it has no deps of its own.
cc_library(
    name = "libtftpu_header",
    hdrs = ["libtftpu.h"],
    visibility = ["//visibility:public"],
    deps = [],
)

cc_library(
    name = "tpu_embedding_configuration_utils",
    srcs = ["tpu_embedding_configuration_utils.cc"],
    hdrs = ["tpu_embedding_configuration_utils.h"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core/protobuf/tpu:optimization_parameters_proto_cc",
        "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:str_format",
    ],
)

cc_library(
    name = "tpu_embedding_errors",
    srcs = ["tpu_embedding_errors.cc"],
    hdrs = ["tpu_embedding_errors.h"],
    deps = [
        "//tensorflow/core/platform:status",
        "//tensorflow/core/platform:statusor",
        "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc",
        "@com_google_absl//absl/strings",
    ],
)

tf_cc_test(
    name = "tpu_embedding_errors_test",
    srcs = ["tpu_embedding_errors_test.cc"],
    deps = [
        ":tpu_embedding_errors",
        "//tensorflow/core:test",
        "//tensorflow/core/platform:errors",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "tpu_embedding_optimization_parameters_utils",
    srcs = ["tpu_embedding_optimization_parameters_utils.cc"],
    hdrs = ["tpu_embedding_optimization_parameters_utils.h"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_proto_cc",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_proto_parsing",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/protobuf/tpu:optimization_parameters_proto_cc",
        "@com_google_absl//absl/base",
    ],
)

cc_library(
    name = "tpu_embedding_output_layout_utils",
    srcs = ["tpu_embedding_output_layout_utils.cc"],
    hdrs = ["tpu_embedding_output_layout_utils.h"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core/framework:tensor_shape_proto_cc",
        "//tensorflow/core/lib/core:status",
        "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc",
    ],
)

cc_library(
    name = "tpu_embedding_configuration_proto_rewrite",
    srcs = ["tpu_embedding_configuration_proto_rewrite.cc"],
    hdrs = ["tpu_embedding_configuration_proto_rewrite.h"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/core/lib/math:math_util",
        "//tensorflow/core/platform:errors",
        "//tensorflow/core/platform:status",
        "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/strings:str_format",
    ],
)

tf_cc_test(
    name = "tpu_embedding_configuration_proto_rewrite_test",
    srcs = ["tpu_embedding_configuration_proto_rewrite_test.cc"],
    deps = [
        ":tpu_embedding_configuration_proto_rewrite",
        "//tensorflow/core:framework_lite",
        "//tensorflow/core:test",
        "//tensorflow/core/lib/core:errors",
        "//tensorflow/core/lib/core:status",
        "//tensorflow/core/platform:casts",
        "//tensorflow/core/platform:status",
        "//tensorflow/core/platform:status_matchers",
        "//tensorflow/core/protobuf:error_codes_proto_impl_cc",
        "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "tpu_node_device_util",
    srcs = ["tpu_node_device_util.cc"],
    hdrs = ["tpu_node_device_util.h"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/compiler/tf2xla:tf2xla_util",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
    ],
)

cc_library(
    name = "tpu_compile_interface",
    srcs = ["tpu_compile_interface.cc"],
    hdrs = ["tpu_compile_interface.h"],
    deps = [
        "//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "tpu_defs",
    srcs = ["tpu_defs.cc"],
    hdrs = ["tpu_defs.h"],
    visibility = ["//visibility:public"],
    deps = ["//tensorflow/core:protos_all_cc"],
)

cc_library(
    name = "tpu_configuration",
    srcs = ["tpu_configuration.cc"],
    hdrs = ["tpu_configuration.h"],
    deps = ["//tensorflow/core:framework"],
)

cc_library(
    name = "tpu_init_mode",
    srcs = ["tpu_init_mode.cc"],
    hdrs = ["tpu_init_mode.h"],
    deps = [
        "//tensorflow/core:lib",
    ],
)

cc_library(
    name = "tpu_initializer_helper",
    srcs = ["tpu_initializer_helper.cc"],
    hdrs = ["tpu_initializer_helper.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":libtftpu_header",
        ":tpu_api",
        ":tpu_api_dlsym_set_fn",
        ":tpu_executor_init_fns",
        ":tpu_library_init_fns",
        ":tpu_ops_c_api_hdrs",
        "//tensorflow/core:lib",
        "//tensorflow/core/platform:logging",
        "//tensorflow/core/platform/cloud:gcs_file_system",
        "//tensorflow/stream_executor/tpu:tpu_executor",
        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/synchronization",
    ],
)

cc_library(
    name = "tpu_api",
    srcs = ["tpu_api.cc"],
    hdrs = ["tpu_api.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":libtftpu_header",
        ":tpu_executor_api",
        ":tpu_ops_c_api_hdrs",
        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
    ],
)

cc_library(
    name = "tpu_executor_api",
    srcs = ["tpu_executor_api.cc"],
    hdrs = ["tpu_executor_api.h"],
    deps = [
        ":libtftpu_header",
        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
    ],
)

cc_library(
    name = "pjrt_api",
    srcs = ["pjrt_api.cc"],
    hdrs = ["pjrt_api.h"],
    deps = [
        "//tensorflow/compiler/xla/pjrt/c:pjrt_c_api_hdrs",
    ],
)

# if_windows picks the Windows-specific source; every other platform builds
# tpu_api_dlsym_initializer.cc.
cc_library(
    name = "tpu_api_dlsym_initializer",
    srcs = if_windows(
        ["tpu_api_dlsym_initializer_windows.cc"],
        otherwise = ["tpu_api_dlsym_initializer.cc"],
    ),
    visibility = ["//visibility:public"],
    deps = [
        ":tpu_initializer_helper",
    ],
    # Always link this in, because even if we don't use it directly we want its
    # static initializers to dynamically load API symbols exported from libtpu.so
    alwayslink = True,
)

# Header-only target for tpu_api_dlsym_set_fn.h.
cc_library(
    name = "tpu_api_dlsym_set_fn",
    hdrs = ["tpu_api_dlsym_set_fn.h"],
    visibility = ["//visibility:public"],
)

# Exposes the tpu_library_init_fns.inc fragment, which builds on the
# :tpu_executor_init_fns fragment.
cc_library(
    name = "tpu_library_init_fns",
    hdrs = ["tpu_library_init_fns.inc"],
    visibility = ["//visibility:public"],
    deps = [":tpu_executor_init_fns"],
)

# Exposes the tpu_executor_init_fns.inc fragment.
cc_library(
    name = "tpu_executor_init_fns",
    hdrs = ["tpu_executor_init_fns.inc"],
    visibility = ["//visibility:public"],
)

cc_library(
    name = "virtual_device",
    srcs = ["virtual_device.cc"],
    hdrs = ["virtual_device.h"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:protos_all_cc",
    ],
)

cc_library(
    name = "tpu_compile",
    srcs = ["tpu_compile.cc"],
    hdrs = ["tpu_compile.h"],
    deps = [
        ":tpu_defs",
        "//tensorflow/compiler/jit:flags_headers",
        "//tensorflow/compiler/jit:shape_inference",
        "//tensorflow/compiler/tf2xla:layout_util",
        "//tensorflow/compiler/tf2xla:tf2xla_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:compile_only_client",
        "//tensorflow/core:core_cpu_base",
        "//tensorflow/core:framework",
        "//tensorflow/core/framework:attr_value_proto_cc",
        "//tensorflow/core/framework:node_def_util",
        "//tensorflow/core/framework:versions_proto_cc",
        "//tensorflow/core/platform:statusor",
        "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
        "//tensorflow/core/tpu/kernels:tpu_compile_op_support",
        "//tensorflow/core/tpu/kernels:tpu_util_hdrs",
    ],
)

cc_library(
    name = "tpu_execute",
    srcs = ["tpu_execute.cc"],
    hdrs = ["tpu_execute.h"],
    deps = [
        ":tpu_api",
        # Same-package target: referenced with a relative label like the
        # other in-package deps in this file.
        ":tpu_ops_c_api_hdrs",
        "//tensorflow/compiler/xla:executable_run_options",
        "//tensorflow/compiler/xla:shape_layout",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/service:computation_layout",
        "//tensorflow/compiler/xla/service:computation_placer",
        "//tensorflow/compiler/xla/service:executable",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_module_config",
        "//tensorflow/compiler/xla/service:hlo_proto_cc",
        "//tensorflow/compiler/xla/service:maybe_owning_device_memory",
        "//tensorflow/compiler/xla/service:transfer_manager",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core/profiler/lib:traceme",
        "//tensorflow/core/tpu/kernels:tpu_executable_info_proto_cc",
        "//tensorflow/core/tpu/kernels:tpu_execute_op_options",
        "//tensorflow/stream_executor:device_memory",
        "//tensorflow/stream_executor:stream",
        "//tensorflow/stream_executor/lib",
        "//tensorflow/stream_executor/tpu:c_api_conversions",
        "//tensorflow/stream_executor/tpu:status_helper",
        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
        "//tensorflow/stream_executor/tpu:tpu_node_context",
        "//tensorflow/stream_executor/tpu:tpu_op_executable",
        "//tensorflow/stream_executor/tpu:tpu_platform_interface",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/memory",
    ],
)

# Has no public headers. alwayslink = True makes Bazel link every object file
# of this library even when nothing references its symbols directly
# (presumably it registers itself via static initializers — confirm).
cc_library(
    name = "tpu_on_demand_compiler",
    srcs = ["tpu_on_demand_compiler.cc"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/service:compiler",
        "//tensorflow/compiler/xla/service:executable",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_cost_analysis",
        "//tensorflow/compiler/xla/service:hlo_module_group",
        "//tensorflow/compiler/xla/service:shaped_buffer",
        "//tensorflow/stream_executor:device_memory_allocator",
        "//tensorflow/stream_executor/tpu:c_api_conversions",
        "//tensorflow/stream_executor/tpu:c_api_decl",
        "//tensorflow/stream_executor/tpu:proto_helper",
        "//tensorflow/stream_executor/tpu:status_helper",
        "//tensorflow/stream_executor/tpu:tpu_executable",
        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
        "//tensorflow/stream_executor/tpu:tpu_executor_hdrs",
        "//tensorflow/stream_executor/tpu:tpu_platform_id",
        "@com_google_absl//absl/cleanup",
        "@com_google_absl//absl/types:span",
    ],
    alwayslink = True,
)

# Aggregate target with no sources of its own. Depending on it pulls in the
# dlsym initializer, the on-demand compiler and //tensorflow/core/tpu/ops.
cc_library(
    name = "tpu_runtime",
    srcs = [],
    visibility = ["//visibility:public"],
    deps = [
        ":tpu_api_dlsym_initializer",
        ":tpu_on_demand_compiler",
        "//tensorflow/core/tpu/ops",
    ],
)

# Header-only target for tpu_ops_c_api.h.
# NOTE(review): this target references //tensorflow/compiler/xla/stream_executor/tpu
# while the rest of this file uses //tensorflow/stream_executor/tpu for the same
# target names (c_api_decl, proto_helper); confirm both resolve to the same
# targets. Also, with no srcs there are no object files for alwayslink to force
# into the link — confirm it is still needed.
cc_library(
    name = "tpu_ops_c_api_hdrs",
    srcs = [],
    hdrs = [
        "tpu_ops_c_api.h",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":libtftpu_header",
        "//tensorflow/c:tf_tstring",
        "//tensorflow/compiler/xla/stream_executor/tpu:c_api_decl",
        "//tensorflow/compiler/xla/stream_executor/tpu:proto_helper",
        "@com_google_absl//absl/types:optional",
    ],
    alwayslink = True,
)

cc_library(
    name = "tpu_fingerprint_utils",
    srcs = ["tpu_fingerprint_utils.cc"],
    hdrs = ["tpu_fingerprint_utils.h"],
    deps = [
        ":tpu_compile_interface",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/core:framework",
        "//tensorflow/core/lib/core:status",
        "//tensorflow/core/lib/strings:proto_serialization",
    ],
)

cc_library(
    name = "tpu_model_server_initializer",
    srcs = ["tpu_model_server_initializer.cc"],
    hdrs = ["tpu_model_server_initializer.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":libtftpu_header",
        ":tpu_api",
        ":tpu_api_dlsym_set_fn",
        ":tpu_executor_init_fns",
        ":tpu_initializer_helper",
        ":tpu_library_init_fns",
        ":tpu_ops_c_api_hdrs",
        "//tensorflow/core:lib",
        "//tensorflow/stream_executor/tpu:tpu_executor",
        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
    ],
    alwayslink = True,
)

# if_libtpu appends //tensorflow/compiler/jit to deps (presumably only for
# libtpu builds — see the macro in //tensorflow:tensorflow.bzl).
cc_library(
    name = "tpu_global_init",
    srcs = ["tpu_global_init.cc"],
    hdrs = ["tpu_global_init.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":tpu_defs",
        "//tensorflow/cc:scope",
        "//tensorflow/cc:tpu_ops",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core/protobuf/tpu:topology_proto_cc",
        "//tensorflow/core/tpu/graph_rewrite:distributed_tpu_configuration_rewrite_pass",
        "//tensorflow/core/tpu/graph_rewrite:distributed_tpu_rewrite_helpers",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
    ] + if_libtpu(["//tensorflow/compiler/jit"]),
)