• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Description: Utilities for TPU Operations
2
3load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
4load(
5    "//tensorflow:tensorflow.bzl",
6    "if_libtpu",
7    "if_windows",
8    "tf_cc_test",
9)
10
# Package-wide defaults. Rules below that do not set `visibility` themselves
# are visible only to the package groups listed here.
package(
    default_visibility = [
        "//tensorflow/compiler/mlir/tensorflow:__subpackages__",
        "//tensorflow/compiler/tf2xla/kernels:__subpackages__",
        "//tensorflow/compiler/xla:__subpackages__",
        "//tensorflow/compiler/xrt:__subpackages__",
        "//tensorflow/core/profiler/backends/tpu:__subpackages__",
        "//tensorflow/core/tpu:__subpackages__",
        "//tensorflow/dtensor:__subpackages__",
        "//tensorflow/stream_executor/tpu:__subpackages__",
    ],
    licenses = ["notice"],
)
24
# Header-only target exporting libtftpu.h. Publicly visible; it declares an
# explicitly empty dependency list.
cc_library(
    name = "libtftpu_header",
    hdrs = [
        "libtftpu.h",
    ],
    visibility = ["//visibility:public"],
    deps = [],
)
31
# Publicly visible library built from tpu_embedding_configuration_utils.{cc,h}.
# Depends on the TPU optimization-parameters and embedding-configuration proto
# targets plus absl StatusOr / str_format.
cc_library(
    name = "tpu_embedding_configuration_utils",
    srcs = [
        "tpu_embedding_configuration_utils.cc",
    ],
    hdrs = [
        "tpu_embedding_configuration_utils.h",
    ],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core/protobuf/tpu:optimization_parameters_proto_cc",
        "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:str_format",
    ],
)
44
# Library built from tpu_embedding_errors.{cc,h}. No `visibility` is set, so
# the package default applies.
cc_library(
    name = "tpu_embedding_errors",
    srcs = [
        "tpu_embedding_errors.cc",
    ],
    hdrs = [
        "tpu_embedding_errors.h",
    ],
    deps = [
        "//tensorflow/core/platform:status",
        "//tensorflow/core/platform:statusor",
        "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc",
        "@com_google_absl//absl/strings",
    ],
)
56
# Test target for :tpu_embedding_errors, linked against gtest_main.
tf_cc_test(
    name = "tpu_embedding_errors_test",
    srcs = [
        "tpu_embedding_errors_test.cc",
    ],
    deps = [
        ":tpu_embedding_errors",
        "//tensorflow/core:test",
        "//tensorflow/core/platform:errors",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest_main",
    ],
)
68
# Publicly visible library built from
# tpu_embedding_optimization_parameters_utils.{cc,h}. Pulls in XLA HLO targets,
# TensorFlow core framework/lib targets and the optimization-parameters proto.
cc_library(
    name = "tpu_embedding_optimization_parameters_utils",
    srcs = [
        "tpu_embedding_optimization_parameters_utils.cc",
    ],
    hdrs = [
        "tpu_embedding_optimization_parameters_utils.h",
    ],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_proto_cc",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_proto_parsing",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/protobuf/tpu:optimization_parameters_proto_cc",
        "@com_google_absl//absl/base",
    ],
)
86
# Publicly visible library built from tpu_embedding_output_layout_utils.{cc,h}.
cc_library(
    name = "tpu_embedding_output_layout_utils",
    srcs = [
        "tpu_embedding_output_layout_utils.cc",
    ],
    hdrs = [
        "tpu_embedding_output_layout_utils.h",
    ],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core/framework:tensor_shape_proto_cc",
        "//tensorflow/core/lib/core:status",
        "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc",
    ],
)
98
# Publicly visible library built from
# tpu_embedding_configuration_proto_rewrite.{cc,h}.
cc_library(
    name = "tpu_embedding_configuration_proto_rewrite",
    srcs = [
        "tpu_embedding_configuration_proto_rewrite.cc",
    ],
    hdrs = [
        "tpu_embedding_configuration_proto_rewrite.h",
    ],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/core/lib/math:math_util",
        "//tensorflow/core/platform:errors",
        "//tensorflow/core/platform:status",
        "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/strings:str_format",
    ],
)
114
# Test target for :tpu_embedding_configuration_proto_rewrite, linked against
# gtest_main.
tf_cc_test(
    name = "tpu_embedding_configuration_proto_rewrite_test",
    srcs = [
        "tpu_embedding_configuration_proto_rewrite_test.cc",
    ],
    deps = [
        ":tpu_embedding_configuration_proto_rewrite",
        "//tensorflow/core:framework_lite",
        "//tensorflow/core:test",
        "//tensorflow/core/lib/core:errors",
        "//tensorflow/core/lib/core:status",
        "//tensorflow/core/platform:casts",
        "//tensorflow/core/platform:status",
        "//tensorflow/core/platform:status_matchers",
        "//tensorflow/core/protobuf:error_codes_proto_impl_cc",
        "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest_main",
    ],
)
133
# Publicly visible library built from tpu_node_device_util.{cc,h}.
cc_library(
    name = "tpu_node_device_util",
    srcs = [
        "tpu_node_device_util.cc",
    ],
    hdrs = [
        "tpu_node_device_util.h",
    ],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/compiler/tf2xla:tf2xla_util",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
    ],
)
145
# Library built from tpu_compile_interface.{cc,h}. Uses the package default
# visibility.
cc_library(
    name = "tpu_compile_interface",
    srcs = [
        "tpu_compile_interface.cc",
    ],
    hdrs = [
        "tpu_compile_interface.h",
    ],
    deps = [
        "//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
    ],
)
155
# Publicly visible library built from tpu_defs.{cc,h}.
cc_library(
    name = "tpu_defs",
    srcs = [
        "tpu_defs.cc",
    ],
    hdrs = [
        "tpu_defs.h",
    ],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:protos_all_cc",
    ],
)
163
# Library built from tpu_configuration.{cc,h}. Uses the package default
# visibility.
cc_library(
    name = "tpu_configuration",
    srcs = [
        "tpu_configuration.cc",
    ],
    hdrs = [
        "tpu_configuration.h",
    ],
    deps = [
        "//tensorflow/core:framework",
    ],
)
170
# Library built from tpu_init_mode.{cc,h}. Uses the package default visibility.
cc_library(
    name = "tpu_init_mode",
    srcs = ["tpu_init_mode.cc"],
    hdrs = ["tpu_init_mode.h"],
    deps = ["//tensorflow/core:lib"],
)
179
# Publicly visible library built from tpu_initializer_helper.{cc,h}. Depends
# on the in-package API / init-function header targets, the GCS file system
# target and the stream_executor TPU executor.
cc_library(
    name = "tpu_initializer_helper",
    srcs = [
        "tpu_initializer_helper.cc",
    ],
    hdrs = [
        "tpu_initializer_helper.h",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":libtftpu_header",
        ":tpu_api",
        ":tpu_api_dlsym_set_fn",
        ":tpu_executor_init_fns",
        ":tpu_library_init_fns",
        ":tpu_ops_c_api_hdrs",
        "//tensorflow/core:lib",
        "//tensorflow/core/platform:logging",
        "//tensorflow/core/platform/cloud:gcs_file_system",
        "//tensorflow/stream_executor/tpu:tpu_executor",
        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/synchronization",
    ],
)
201
# Publicly visible library built from tpu_api.{cc,h}.
cc_library(
    name = "tpu_api",
    srcs = [
        "tpu_api.cc",
    ],
    hdrs = [
        "tpu_api.h",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":libtftpu_header",
        ":tpu_executor_api",
        ":tpu_ops_c_api_hdrs",
        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
    ],
)
214
# Library built from tpu_executor_api.{cc,h}. Uses the package default
# visibility.
cc_library(
    name = "tpu_executor_api",
    srcs = [
        "tpu_executor_api.cc",
    ],
    hdrs = [
        "tpu_executor_api.h",
    ],
    deps = [
        ":libtftpu_header",
        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
    ],
)
224
# Library built from pjrt_api.{cc,h}; its only dependency is the PJRT C API
# headers target. Uses the package default visibility.
cc_library(
    name = "pjrt_api",
    srcs = ["pjrt_api.cc"],
    hdrs = ["pjrt_api.h"],
    deps = ["//tensorflow/compiler/xla/pjrt/c:pjrt_c_api_hdrs"],
)
233
# Publicly visible library whose single source file is chosen by platform:
# the *_windows.cc variant on Windows, the plain .cc everywhere else.
cc_library(
    name = "tpu_api_dlsym_initializer",
    srcs = if_windows(
        [
            "tpu_api_dlsym_initializer_windows.cc",
        ],
        otherwise = [
            "tpu_api_dlsym_initializer.cc",
        ],
    ),
    visibility = ["//visibility:public"],
    deps = [":tpu_initializer_helper"],
    # Force-link this target: even when nothing references it directly, its
    # static initializers must run so the API symbols exported from libtpu.so
    # get loaded dynamically.
    alwayslink = True,
)
248
# Header-only, publicly visible target exporting tpu_api_dlsym_set_fn.h.
cc_library(
    name = "tpu_api_dlsym_set_fn",
    hdrs = [
        "tpu_api_dlsym_set_fn.h",
    ],
    visibility = ["//visibility:public"],
)
254
# Publicly visible target that exposes tpu_library_init_fns.inc through `hdrs`
# and depends on :tpu_executor_init_fns.
cc_library(
    name = "tpu_library_init_fns",
    hdrs = [
        "tpu_library_init_fns.inc",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":tpu_executor_init_fns",
    ],
)
261
# Publicly visible target that exposes tpu_executor_init_fns.inc through
# `hdrs`; it has no sources or dependencies.
cc_library(
    name = "tpu_executor_init_fns",
    hdrs = [
        "tpu_executor_init_fns.inc",
    ],
    visibility = ["//visibility:public"],
)
267
# Publicly visible library built from virtual_device.{cc,h}.
cc_library(
    name = "virtual_device",
    srcs = [
        "virtual_device.cc",
    ],
    hdrs = [
        "virtual_device.h",
    ],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:protos_all_cc",
    ],
)
278
# Library built from tpu_compile.{cc,h}. Uses the package default visibility.
# Depends on :tpu_defs, JIT and tf2xla targets, the XLA compile-only client,
# TensorFlow core framework targets and the TPU kernels support targets.
cc_library(
    name = "tpu_compile",
    srcs = [
        "tpu_compile.cc",
    ],
    hdrs = [
        "tpu_compile.h",
    ],
    deps = [
        ":tpu_defs",
        "//tensorflow/compiler/jit:flags_headers",
        "//tensorflow/compiler/jit:shape_inference",
        "//tensorflow/compiler/tf2xla:layout_util",
        "//tensorflow/compiler/tf2xla:tf2xla_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:compile_only_client",
        "//tensorflow/core:core_cpu_base",
        "//tensorflow/core:framework",
        "//tensorflow/core/framework:attr_value_proto_cc",
        "//tensorflow/core/framework:node_def_util",
        "//tensorflow/core/framework:versions_proto_cc",
        "//tensorflow/core/platform:statusor",
        "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
        "//tensorflow/core/tpu/kernels:tpu_compile_op_support",
        "//tensorflow/core/tpu/kernels:tpu_util_hdrs",
    ],
)
303
# Library built from tpu_execute.{cc,h}. Uses the package default visibility.
# Same-package dependencies are written as relative labels (":name"), matching
# the other rules in this file.
cc_library(
    name = "tpu_execute",
    srcs = ["tpu_execute.cc"],
    hdrs = ["tpu_execute.h"],
    deps = [
        ":tpu_api",
        # Was spelled "//tensorflow/core/tpu:tpu_ops_c_api_hdrs"; it is the
        # same target, defined in this package.
        ":tpu_ops_c_api_hdrs",
        "//tensorflow/compiler/xla:executable_run_options",
        "//tensorflow/compiler/xla:shape_layout",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/service:computation_layout",
        "//tensorflow/compiler/xla/service:computation_placer",
        "//tensorflow/compiler/xla/service:executable",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_module_config",
        "//tensorflow/compiler/xla/service:hlo_proto_cc",
        "//tensorflow/compiler/xla/service:maybe_owning_device_memory",
        "//tensorflow/compiler/xla/service:transfer_manager",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core/profiler/lib:traceme",
        "//tensorflow/core/tpu/kernels:tpu_executable_info_proto_cc",
        "//tensorflow/core/tpu/kernels:tpu_execute_op_options",
        "//tensorflow/stream_executor:device_memory",
        "//tensorflow/stream_executor:stream",
        "//tensorflow/stream_executor/lib",
        "//tensorflow/stream_executor/tpu:c_api_conversions",
        "//tensorflow/stream_executor/tpu:status_helper",
        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
        "//tensorflow/stream_executor/tpu:tpu_node_context",
        "//tensorflow/stream_executor/tpu:tpu_op_executable",
        "//tensorflow/stream_executor/tpu:tpu_platform_interface",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/memory",
    ],
)
345
# Publicly visible, source-only library (no `hdrs`) built from
# tpu_on_demand_compiler.cc. `alwayslink` is set, so the object code is kept
# by the linker even when no symbol in it is referenced.
cc_library(
    name = "tpu_on_demand_compiler",
    srcs = [
        "tpu_on_demand_compiler.cc",
    ],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/service:compiler",
        "//tensorflow/compiler/xla/service:executable",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_cost_analysis",
        "//tensorflow/compiler/xla/service:hlo_module_group",
        "//tensorflow/compiler/xla/service:shaped_buffer",
        "//tensorflow/stream_executor:device_memory_allocator",
        "//tensorflow/stream_executor/tpu:c_api_conversions",
        "//tensorflow/stream_executor/tpu:c_api_decl",
        "//tensorflow/stream_executor/tpu:proto_helper",
        "//tensorflow/stream_executor/tpu:status_helper",
        "//tensorflow/stream_executor/tpu:tpu_executable",
        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
        "//tensorflow/stream_executor/tpu:tpu_executor_hdrs",
        "//tensorflow/stream_executor/tpu:tpu_platform_id",
        "@com_google_absl//absl/cleanup",
        "@com_google_absl//absl/types:span",
    ],
    alwayslink = True,
)
374
# Publicly visible aggregate target: it has no sources of its own and only
# bundles the dlsym initializer, the on-demand compiler and the TPU ops
# package. Same-package dependencies use relative labels, matching the rest of
# this file.
cc_library(
    name = "tpu_runtime",
    srcs = [],
    visibility = ["//visibility:public"],
    deps = [
        ":tpu_api_dlsym_initializer",
        # Was spelled "//tensorflow/core/tpu:tpu_on_demand_compiler"; it is
        # the same target, defined in this package.
        ":tpu_on_demand_compiler",
        "//tensorflow/core/tpu/ops",
    ],
)
385
# Publicly visible header target exporting tpu_ops_c_api.h.
# NOTE(review): the c_api_decl / proto_helper deps use the
# //tensorflow/compiler/xla/stream_executor/tpu path, while other rules in this
# file reference //tensorflow/stream_executor/tpu - confirm this is intended.
# NOTE(review): `alwayslink` is set although `srcs` is empty - confirm it is
# still needed.
cc_library(
    name = "tpu_ops_c_api_hdrs",
    srcs = [],
    hdrs = ["tpu_ops_c_api.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":libtftpu_header",
        "//tensorflow/c:tf_tstring",
        "//tensorflow/compiler/xla/stream_executor/tpu:c_api_decl",
        "//tensorflow/compiler/xla/stream_executor/tpu:proto_helper",
        "@com_google_absl//absl/types:optional",
    ],
    alwayslink = True,
)
402
# Library built from tpu_fingerprint_utils.{cc,h}. Uses the package default
# visibility.
cc_library(
    name = "tpu_fingerprint_utils",
    srcs = [
        "tpu_fingerprint_utils.cc",
    ],
    hdrs = [
        "tpu_fingerprint_utils.h",
    ],
    deps = [
        ":tpu_compile_interface",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/core:framework",
        "//tensorflow/core/lib/core:status",
        "//tensorflow/core/lib/strings:proto_serialization",
    ],
)
415
# Publicly visible library built from tpu_model_server_initializer.{cc,h}.
# `alwayslink` is set, so the object code is kept by the linker even when no
# symbol in it is referenced.
cc_library(
    name = "tpu_model_server_initializer",
    srcs = [
        "tpu_model_server_initializer.cc",
    ],
    hdrs = [
        "tpu_model_server_initializer.h",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":libtftpu_header",
        ":tpu_api",
        ":tpu_api_dlsym_set_fn",
        ":tpu_executor_init_fns",
        ":tpu_initializer_helper",
        ":tpu_library_init_fns",
        ":tpu_ops_c_api_hdrs",
        "//tensorflow/core:lib",
        "//tensorflow/stream_executor/tpu:tpu_executor",
        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
    ],
    alwayslink = True,
)
435
# Publicly visible library built from tpu_global_init.{cc,h}.
# The dependency list is kept in the same order as every other rule in this
# file: same-package labels, then "//" workspace labels, then "@" external
# labels. //tensorflow/compiler/jit is appended through the `if_libtpu` macro
# loaded from tensorflow.bzl.
cc_library(
    name = "tpu_global_init",
    srcs = ["tpu_global_init.cc"],
    hdrs = ["tpu_global_init.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":tpu_defs",
        "//tensorflow/cc:scope",
        "//tensorflow/cc:tpu_ops",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core/protobuf/tpu:topology_proto_cc",
        "//tensorflow/core/tpu/graph_rewrite:distributed_tpu_configuration_rewrite_pass",
        "//tensorflow/core/tpu/graph_rewrite:distributed_tpu_rewrite_helpers",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
    ] + if_libtpu(["//tensorflow/compiler/jit"]),
)
457