• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
2
3# buildifier: disable=same-origin-load
4load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "tf_cc_test")
5
6# buildifier: disable=same-origin-load
7load("//tensorflow:tensorflow.bzl", "if_libtpu", "if_with_tpu_support", "tf_copts")
8load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
9
10# buildifier: disable=same-origin-load
11load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
12
13# buildifier: disable=same-origin-load
14load("//tensorflow:tensorflow.bzl", "filegroup")
15
16# buildifier: disable=same-origin-load
17load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps")
18load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library")
19load(
20    "//tensorflow/core/platform:build_config_root.bzl",
21    "if_static",
22    "tf_cuda_tests_tags",
23)
24
# Package defaults: targets without an explicit `visibility` are visible to
# the ":internal" package group (defined below) and to the Cloud TPU
# inference converter package.
package(
    default_visibility = [
        ":internal",
        "//third_party/cloud_tpu/inference_converter:__pkg__",
    ],
    licenses = ["notice"],
)
32
# Package group used as this package's default visibility. It contains every
# package in tf2xla's ":internal" group plus the package trees listed below.
package_group(
    name = "internal",
    includes = [
        "//tensorflow/compiler/tf2xla:internal",
    ],
    packages = [
        "//tensorflow/c/...",
        "//tensorflow/compiler/tests/...",
        "//tensorflow/python/...",
    ],
)
44
# Package group referenced by targets declared with visibility = [":friends"].
# Its membership comes entirely from tf2xla's ":friends" group.
package_group(
    name = "friends",
    includes = [
        "//tensorflow/compiler/tf2xla:friends",
    ],
)
51
# defs.cc/defs.h contain only string constants, so they can be included in
# mobile builds.
filegroup(
    name = "mobile_srcs_no_runtime",
    # The same two files that the ":common" cc_library compiles, exposed here
    # as raw sources.
    srcs = [
        "defs.cc",
        "defs.h",
    ],
    visibility = [":friends"],
)
62
# Target that bundles up the XLA CPU JIT device, plus the GPU device (via
# if_cuda_or_rocm) and the TPU device (via if_with_tpu_support).
cc_library(
    name = "jit",
    visibility = [
        ":friends",
        "//learning/tfx:__subpackages__",
    ],
    # This target has no srcs of its own; it only aggregates device and JIT
    # libraries.
    deps = [
        ":xla_cpu_device",
        ":xla_cpu_jit",
        "//tensorflow/compiler/plugin",
    ] + if_cuda_or_rocm([
        # GPU device and JIT targets, added through if_cuda_or_rocm.
        ":xla_gpu_device",
        ":xla_gpu_jit",
    ]) + if_with_tpu_support([
        # TPU device and JIT targets, added through if_with_tpu_support.
        ":xla_tpu_device",
        ":xla_tpu_jit",
    ]),
    alwayslink = 1,
)
83
# JIT support for XLA on CPU: JIT compilation passes, the XLA kernel creator
# and the XLA op kernels.
cc_library(
    name = "xla_cpu_jit",
    visibility = ["//visibility:public"],
    deps = [
        ":jit_compilation_passes",
        ":xla_kernel_creator",  # buildcleaner: keep
        "//tensorflow/compiler/jit/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
    ] + if_libtpu(
        # The XLA CPU backend plugin is linked only in the if_libtpu
        # "false" configuration.
        if_false = ["//tensorflow/compiler/xla/service:cpu_plugin"],
        if_true = [],
    ),
    alwayslink = 1,
)
99
# JIT support for XLA on GPU. This mirrors ":xla_cpu_jit", except that every
# dependency (including the GPU backend plugin) is added through
# if_cuda_or_rocm; presumably the target has no deps in non-CUDA/ROCm builds.
cc_library(
    name = "xla_gpu_jit",
    visibility = ["//visibility:public"],
    # Kept in the same (sorted) order as the corresponding deps of
    # ":xla_cpu_jit".
    deps = if_cuda_or_rocm([
        ":jit_compilation_passes",
        ":xla_kernel_creator",  # buildcleaner: keep
        "//tensorflow/compiler/jit/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/compiler/xla/service:gpu_plugin",
    ]),
    alwayslink = 1,
)
113
# JIT support for XLA on TPU. All deps are added through if_libtpu: the TPU
# graph-rewrite pass registrations and the TPU transfer manager.
cc_library(
    name = "xla_tpu_jit",
    visibility = ["//visibility:public"],
    deps = if_libtpu([
        "//tensorflow/core/tpu/graph_rewrite:tpu_rewrite_pass_registration",
        "//tensorflow/core/tpu/graph_rewrite:configure_tpu_embedding_rewrite_registration",
        "//tensorflow/stream_executor/tpu:tpu_transfer_manager",
    ]),
    alwayslink = 1,
)
124
# XLA CPU device, implemented in xla_cpu_device.cc on top of ":xla_device".
cc_library(
    name = "xla_cpu_device",
    srcs = ["xla_cpu_device.cc"],
    visibility = [":friends"],
    deps = [
        ":common",
        ":flags",
        ":jit_compilation_passes",
        ":xla_device",
        ":xla_kernel_creator",  # buildcleaner: keep
        "@com_google_absl//absl/memory",
        "//tensorflow/compiler/jit/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla:layout_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla:xla_op_registry",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:lib",
    ] + if_libtpu(
        # The CPU backend plugin is linked only in the if_libtpu "false"
        # configuration.
        if_false = [
            "//tensorflow/compiler/xla/service:cpu_plugin",  # buildcleaner: keep
        ],
        if_true = [],
    ),
    alwayslink = 1,
)
151
# XLA GPU device, implemented in xla_gpu_device.cc on top of ":xla_device".
cc_library(
    name = "xla_gpu_device",
    srcs = ["xla_gpu_device.cc"],
    visibility = [":friends"],
    deps = [
        ":common",
        ":flags",
        ":jit_compilation_passes",
        ":xla_device",
        ":xla_kernel_creator",  # buildcleaner: keep
        # NOTE(review): ":xla_device" already depends on this target; the
        # direct dep is presumably kept for layering/include checks — confirm.
        ":xla_device_no_jit_rewrite_registration",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "//tensorflow/compiler/jit/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla:layout_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla:xla_op_registry",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core/common_runtime/gpu:gpu_init",
    ] + if_libtpu(
        # The GPU backend plugin is linked only in the if_libtpu "false"
        # configuration.
        if_false = [
            "//tensorflow/compiler/xla/service:gpu_plugin",  # buildcleaner: keep
        ],
        if_true = [],
    ),
    alwayslink = 1,
)
181
# XLA TPU device (xla_tpu_device.cc/.h), built on ":xla_device" and the TPU
# stream_executor / core TPU libraries.
cc_library(
    name = "xla_tpu_device",
    srcs = ["xla_tpu_device.cc"],
    hdrs = ["xla_tpu_device.h"],
    visibility = [":friends"],
    deps = [
        ":xla_device",
        ":xla_kernel_creator",  # buildcleaner: keep
        "@com_google_absl//absl/types:optional",
        "//tensorflow/compiler/jit/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:layout_util",
        "//tensorflow/compiler/tf2xla:tf2xla_util",
        "//tensorflow/compiler/tf2xla:xla_helpers",
        "//tensorflow/compiler/tf2xla:xla_op_registry",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib_proto_parsing",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:session_options",
        "//tensorflow/core/common_runtime:device",
        "//tensorflow/core/common_runtime:device_factory",
        "//tensorflow/core/common_runtime:dma_helper",
        "//tensorflow/core/platform:status",
        "//tensorflow/core/tpu:tpu_api",
        "//tensorflow/core/tpu:tpu_defs",
        "//tensorflow/core/tpu:tpu_node_device_util",
        "//tensorflow/core/tpu:virtual_device",
        "//tensorflow/stream_executor/tpu:c_api_conversions",
        "//tensorflow/stream_executor/tpu:status_helper",
        "//tensorflow/stream_executor/tpu:tpu_executor_base",
        "//tensorflow/stream_executor/tpu:tpu_node_context",
        "//tensorflow/stream_executor/tpu:tpu_platform_interface",
        "//tensorflow/stream_executor/tpu:tpu_stream_interface",
    ] + if_static([
        # Added only through if_static (static-linking configurations).
        "//tensorflow/core/common_runtime:copy_tensor",
        ":jit_compilation_passes",
    ]),
    alwayslink = 1,
)
221
# Library for xla_tensor.cc/.h; depends on XLA shaped buffers and the local
# client.
cc_library(
    name = "xla_tensor",
    srcs = ["xla_tensor.cc"],
    hdrs = ["xla_tensor.h"],
    visibility = [":friends"],
    deps = [
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla/client:local_client",
        "//tensorflow/compiler/xla/service:shaped_buffer",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "@com_google_absl//absl/memory",
    ],
)
239
# Dependencies shared by ":xla_device_no_jit_rewrite_registration" and
# ":xla_device"; both targets extend this list with their own extra deps.
# Entries are kept sorted (local labels, then absl, then //tensorflow labels),
# matching the ordering convention used elsewhere in this file.
XLA_DEVICE_DEPS = [
    ":common",
    ":xla_launch_util",
    ":xla_tensor",
    "@com_google_absl//absl/base",
    "@com_google_absl//absl/memory",
    "@com_google_absl//absl/strings",
    "@com_google_absl//absl/synchronization",
    "@com_google_absl//absl/types:optional",
    "//tensorflow/compiler/jit/ops:xla_ops",
    "//tensorflow/compiler/tf2xla:common",
    "//tensorflow/compiler/tf2xla:layout_util",
    "//tensorflow/compiler/tf2xla:tf2xla_util",
    "//tensorflow/compiler/tf2xla:xla_compiler",
    "//tensorflow/compiler/tf2xla:xla_op_registry",
    "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
    "//tensorflow/compiler/tf2xla/kernels:xla_ops",
    "//tensorflow/compiler/xla:util",
    "//tensorflow/compiler/xla/client:client_library",
    "//tensorflow/compiler/xla/client:global_data",
    "//tensorflow/compiler/xla/client:local_client",
    "//tensorflow/compiler/xla/service:stream_pool",
    "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
    "//tensorflow/core:array_ops_op_lib",
    "//tensorflow/core:control_flow_ops_op_lib",
    "//tensorflow/core:core_cpu",
    "//tensorflow/core:core_cpu_internal",
    "//tensorflow/core:dataset_ops_op_lib",
    "//tensorflow/core:framework",
    "//tensorflow/core:framework_internal",
    "//tensorflow/core:functional_ops_op_lib",
    "//tensorflow/core:lib",
    "//tensorflow/core:lib_internal",
    "//tensorflow/core:math_ops_op_lib",
    "//tensorflow/core:nn_ops_op_lib",
    "//tensorflow/core:no_op_op_lib",
    "//tensorflow/core:protos_all_cc",
    "//tensorflow/core:resource_variable_ops_op_lib",
    "//tensorflow/core:sendrecv_ops_op_lib",
    "//tensorflow/core:state_ops_op_lib",
    "//tensorflow/core/kernels:constant_op",
    "//tensorflow/core/kernels:fifo_queue",
    "//tensorflow/core/kernels:function_ops",
    "//tensorflow/core/kernels:identity_op",
    "//tensorflow/core/kernels:resource_variable_ops",
    "//tensorflow/core/kernels:shape_ops",
    "//tensorflow/core/kernels:variable_ops",
    "//tensorflow/core/kernels/data:finalize_dataset_op",
    "//tensorflow/core/kernels/data:generator_dataset_op",
    "//tensorflow/core/kernels/data:iterator_ops",
    "//tensorflow/core/kernels/data:optional_ops",
    "//tensorflow/core/kernels/data:options_dataset_op",
    "//tensorflow/core/kernels/data:prefetch_dataset_op",
    "//tensorflow/core/platform:stream_executor_no_cuda",
    "//tensorflow/core/profiler/lib:traceme",
    "//tensorflow/stream_executor:tf_allocator_adapter",
    "//tensorflow/stream_executor/platform",
]
298
# XLA device implementation (device, device context, device ops, on-demand
# compile op, platform info) WITHOUT a dependency on ":jit_compilation_passes".
# ":xla_device" below adds that dependency on top of this target.
cc_library(
    name = "xla_device_no_jit_rewrite_registration",
    srcs = [
        "xla_compile_on_demand_op.cc",
        "xla_device.cc",
        "xla_device_context.cc",
        "xla_device_ops.cc",
        "xla_ops_on_regular_devices.cc",
        "xla_platform_info.cc",
    ],
    hdrs = [
        "xla_compile_on_demand_op.h",
        "xla_device.h",
        "xla_device_context.h",
        "xla_device_ops.h",
        "xla_platform_info.h",
    ],
    # Public visibility is needed for external TF/XLA backends.
    visibility = ["//visibility:public"],
    # Uses ":flags_headers" (header-only) rather than ":flags".
    deps = XLA_DEVICE_DEPS + [
        ":flags_headers",
        ":xla_compilation_cache",
    ],
    alwayslink = 1,
)
324
# Header-only facade: no srcs. It re-exports the device headers and links
# ":xla_device_no_jit_rewrite_registration" together with
# ":jit_compilation_passes". Note that xla_platform_info.h is exported only by
# ":xla_device_no_jit_rewrite_registration", not by this target.
cc_library(
    name = "xla_device",
    hdrs = [
        "xla_compile_on_demand_op.h",
        "xla_device.h",
        "xla_device_context.h",
        "xla_device_ops.h",
    ],
    # Public visibility is needed for external TF/XLA backends.
    visibility = ["//visibility:public"],
    deps = XLA_DEVICE_DEPS + [
        ":jit_compilation_passes",
        ":xla_device_no_jit_rewrite_registration",
    ],
)
340
cc_library(
    name = "shape_inference_helpers",
    srcs = ["shape_inference_helpers.cc"],
    hdrs = ["shape_inference_helpers.h"],
    visibility = [":friends"],
    # Android builds use the portable TensorFlow library; all other builds
    # use //tensorflow/core:graph.
    deps = select({
        "//tensorflow:android": [
            "//tensorflow/core:portable_tensorflow_lib",
        ],
        "//conditions:default": [
            "//tensorflow/core:graph",
        ],
    }),
)
355
# JIT flags library (flags.cc/.h). See ":flags_headers" for the header-only
# variant used where linking flags.cc would cause ODR violations.
cc_library(
    name = "flags",
    srcs = ["flags.cc"],
    hdrs = ["flags.h"],
    visibility = [":friends"],
    deps = [
        "//tensorflow/compiler/mlir/tensorflow:dump_graph",
        "//tensorflow/compiler/xla:parse_flags_from_env",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
    ],
)
372
373# Header-only version of "flags" library, for linking from the shared object
374# without ODR violations.
cc_library(
    name = "flags_headers",
    hdrs = ["flags.h"],
    visibility = [":friends"],
    # NOTE(review): this list differs from ":flags" (for_core_protos_cc
    # instead of protos_all_cc, no absl/base); keep it in sync with what
    # flags.h actually includes.
    deps = [
        "//tensorflow/compiler/mlir/tensorflow:dump_graph",
        "//tensorflow/compiler/xla:parse_flags_from_env",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core/protobuf:for_core_protos_cc",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
    ],
)
389
# Header-only wrapper (cc_header_only_library) around ":flags_headers".
cc_header_only_library(
    name = "flags_headers_only",
    # Disables header parsing for this target.
    features = [
        "-parse_headers",  # buildifier: disable=no-parse-headers
    ],
    deps = [":flags_headers"],
)
397
# Compiles defs.cc/defs.h, the same files exported raw by
# ":mobile_srcs_no_runtime". No deps.
cc_library(
    name = "common",
    srcs = [
        "defs.cc",
    ],
    hdrs = [
        "defs.h",
    ],
    visibility = [":friends"],
)
408
409# Internal targets below this point.
410
# Library for xla_launch_util.cc/.h; depends on the compilation cache,
# XlaTensor and the XLA local client.
cc_library(
    name = "xla_launch_util",
    srcs = ["xla_launch_util.cc"],
    hdrs = ["xla_launch_util.h"],
    visibility = [
        ":internal",
        # We reuse VariableInfo in TFRT's implementation of TpuExecuteOp.
        "//learning/brain/tfrt/tf_tpu:__pkg__",
        "//learning/brain/tfrt/tpu_common:__pkg__",
    ],
    deps = [
        ":common",
        ":xla_compilation_cache",
        ":xla_tensor",
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla/client:client_library",
        "//tensorflow/compiler/xla/client:local_client",
        "//tensorflow/compiler/xla/service:shaped_buffer",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:gpu_runtime",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/stream_executor:device_memory_allocator",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/cleanup",
        "@com_google_absl//absl/memory",
    ],
)
445
# Proto definitions for the XLA compilation cache. The C++ target
# ":xla_compilation_cache_proto_cc" (used by ":xla_compilation_cache") is
# presumably generated by this macro — confirm in build_config.bzl.
tf_proto_library(
    name = "xla_compilation_cache_proto",
    srcs = ["xla_compilation_cache.proto"],
    cc_api_version = 2,
    protodeps = tf_additional_all_protos() + ["//tensorflow/compiler/xla/service:hlo_proto"],
    visibility = ["//visibility:public"],
)
453
# XLA compilation cache library. No explicit visibility, so the package
# default (":internal" plus the Cloud TPU inference converter) applies.
cc_library(
    name = "xla_compilation_cache",
    srcs = ["xla_compilation_cache.cc"],
    hdrs = ["xla_compilation_cache.h"],
    copts = tf_copts(),
    deps = [
        ":flags",
        ":xla_activity_listener",
        ":xla_activity_proto_cc",
        ":xla_cluster_util",
        ":xla_compilation_cache_proto_cc",
        "//tensorflow/compiler/mlir:array_container_utils",
        "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
        "//tensorflow/compiler/mlir/tf2xla:mlir_bridge_rollout_policy",
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla:xla_context",
        "//tensorflow/compiler/xla:protobuf_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla/client:client_library",
        "//tensorflow/compiler/xla/client:local_client",
        "//tensorflow/compiler/xla/service:compiler",
        "//tensorflow/compiler/xla/service:hlo_proto_cc",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/platform:logging",
        "//tensorflow/core/tpu:tpu_defs",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
        "@com_google_absl//absl/types:variant",
    ],
)
497
# Unit test for ":xla_compilation_cache"; links ":xla_cpu_jit" so that the CPU
# JIT is available in the test binary.
tf_cc_test(
    name = "xla_compilation_cache_test",
    srcs = [
        "xla_compilation_cache_test.cc",
    ],
    deps = [
        ":flags",
        ":xla_compilation_cache",
        ":xla_cpu_jit",
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/xla/client:client_library",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)
513
# Separate test binary (xla_compilation_cache_disable_test.cc) for the
# compilation cache; same deps as ":xla_compilation_cache_test" plus
# tf2xla:xla_compiler.
tf_cc_test(
    name = "xla_compilation_cache_disable_test",
    srcs = [
        "xla_compilation_cache_disable_test.cc",
    ],
    deps = [
        ":flags",
        ":xla_compilation_cache",
        ":xla_cpu_jit",
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/xla/client:client_library",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)
530
# Registration unit for the JIT compilation passes
# (jit_compilation_pass_registration.cc). alwayslink keeps the object file in
# the link even if nothing references it directly.
cc_library(
    name = "jit_compilation_passes",
    srcs = ["jit_compilation_pass_registration.cc"],
    deps = [
        ":compilation_passes",
        ":xla_activity_logging_listener",
        "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration",
        "//tensorflow/compiler/tf2xla:mlir_bridge_pass_registration",
        "//tensorflow/core:core_cpu_internal",
    ] + tf_jit_compilation_passes_extra_deps(),
    alwayslink = 1,
)
543
# Library for get_compiler_ir.cc/.h. It depends on the device library variant
# without JIT rewrite registration (":xla_device_no_jit_rewrite_registration").
cc_library(
    name = "get_compiler_ir",
    srcs = ["get_compiler_ir.cc"],
    hdrs = ["get_compiler_ir.h"],
    visibility = [":internal"],
    deps = [
        ":common",
        ":compilability_check_util",
        ":flags",
        ":xla_device_no_jit_rewrite_registration",
        ":xla_launch_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla/client:executable_build_options",
        "//tensorflow/compiler/xla/client:local_client",
        "//tensorflow/compiler/xla/service:hlo_graph_dumper",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core/common_runtime:core_cpu_internal",
        "//tensorflow/core/common_runtime/eager:tensor_handle",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:span",
    ],
    alwayslink = 1,
)
571
# Header-only version of the "get_compiler_ir" library, for linking from the
# shared object without ODR violations.
cc_library(
    name = "get_compiler_ir_hdrs",
    # Listed as a textual header: it is not compiled on its own as part of
    # this target.
    textual_hdrs = ["get_compiler_ir.h"],
    visibility = [":internal"],
    deps = [
        "//tensorflow/compiler/xla:statusor",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:span",
    ],
)
586
# Header-only wrapper (cc_header_only_library) around ":get_compiler_ir_hdrs",
# analogous to ":flags_headers_only".
cc_header_only_library(
    name = "get_compiler_ir_hdrs_only",
    # Disables header parsing for this target.
    features = [
        "-parse_headers",  # buildifier: disable=no-parse-headers
    ],
    deps = [":get_compiler_ir_hdrs"],
)
594
# XLA device plugins can depend on this target to avoid circular dependencies;
# it provides all of the headers required to build a device library.
cc_header_only_library(
    name = "xla_jit_headers_lib",
    visibility = ["//visibility:public"],
    # Covers the CPU and GPU device/JIT targets only; the TPU targets
    # (":xla_tpu_device", ":xla_tpu_jit") are not included.
    deps = [
        ":xla_cpu_device",
        ":xla_cpu_jit",
        ":xla_gpu_device",
        ":xla_gpu_jit",
    ],
)
606
cc_library(
    name = "xla_kernel_creator",
    # xla_kernel_creator.h is listed in srcs, not hdrs, so it is private to
    # this target; ":xla_kernel_creator_test" lists it in its own srcs.
    srcs = [
        "xla_kernel_creator.cc",
        "xla_kernel_creator.h",
    ],
    visibility = [
        ":internal",
        "//tensorflow/core/common_runtime/eager:__pkg__",
    ],
    deps = [
        ":common",
        ":compilability_check_util",
        ":compilation_passes",
        ":flags",
        ":jit_compilation_passes",
        "//tensorflow/compiler/jit/kernels:xla_ops_no_jit_rewrite_registration",
        "//tensorflow/compiler/tf2xla:mlir_bridge_pass",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla:xla_op_registry",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
    ],
    alwayslink = 1,
)
637
tf_cc_test(
    name = "xla_kernel_creator_test",
    # The private header of ":xla_kernel_creator" is compiled into the test
    # directly.
    srcs = [
        "xla_kernel_creator.h",
        "xla_kernel_creator_test.cc",
    ],
    deps = [
        ":xla_kernel_creator",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:session_options",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
        "@com_google_absl//absl/memory",
    ],
)
657
# Library for resource_operation_safety_analysis.cc/.h; uses the tf2xla
# resource operation table and graph cycle detection. Package-default
# visibility.
cc_library(
    name = "resource_operation_safety_analysis",
    srcs = ["resource_operation_safety_analysis.cc"],
    hdrs = ["resource_operation_safety_analysis.h"],
    deps = [
        ":xla_cluster_util",
        "//tensorflow/compiler/tf2xla:resource_operation_table",
        "//tensorflow/compiler/xla/service/graphcycles",
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
    ],
)
676
# Unit test for ":resource_operation_safety_analysis"; builds test graphs with
# the TensorFlow C++ ops API.
tf_cc_test(
    name = "resource_operation_safety_analysis_test",
    srcs = ["resource_operation_safety_analysis_test.cc"],
    deps = [
        ":common",
        ":resource_operation_safety_analysis",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:cc_ops_internal",
        "//tensorflow/cc:function_ops",
        "//tensorflow/cc:functional_ops",
        "//tensorflow/cc:ops",
        "//tensorflow/cc:resource_variable_ops",
        "//tensorflow/cc:sendrecv_ops",
        "//tensorflow/compiler/jit/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
        "@com_google_absl//absl/strings",
    ],
)
704
# Library for shape_inference.cc/.h, built on ":shape_inference_helpers".
cc_library(
    name = "shape_inference",
    srcs = ["shape_inference.cc"],
    hdrs = ["shape_inference.h"],
    visibility = [":friends"],
    deps = [
        ":shape_inference_helpers",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
    ],
)
720
# Test-only helpers (test_util.cc/.h); testonly restricts use to tests and
# other testonly targets.
cc_library(
    name = "test_util",
    testonly = 1,
    srcs = ["test_util.cc"],
    hdrs = ["test_util.h"],
    deps = [
        ":shape_inference",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
    ],
)
734
# Unit test for ":shape_inference", using ":test_util" helpers.
tf_cc_test(
    name = "shape_inference_test",
    srcs = ["shape_inference_test.cc"],
    deps = [
        ":shape_inference",
        ":test_util",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:cc_ops_internal",
        "//tensorflow/cc:ops",
        "//tensorflow/cc:resource_variable_ops",
        "//tensorflow/core:framework",
        "//tensorflow/core:ops",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/kernels:constant_op",
    ],
)
752
# Library for encapsulate_util.cc/.h; depends on ":shape_inference" and
# tf2xla utilities. Package-default visibility.
cc_library(
    name = "encapsulate_util",
    srcs = ["encapsulate_util.cc"],
    hdrs = ["encapsulate_util.h"],
    deps = [
        ":shape_inference",
        "//tensorflow/compiler/tf2xla:tf2xla_util",
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
    ],
)
770
# Unit test for ":encapsulate_util".
tf_cc_test(
    name = "encapsulate_util_test",
    srcs = ["encapsulate_util_test.cc"],
    deps = [
        ":encapsulate_util",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:scope",
        "//tensorflow/core:framework",
        "//tensorflow/core:ops",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)
785
# The JIT graph passes (clustering, deadness analysis, encapsulation, XLA op
# building, etc.). Pass registration lives separately in
# ":jit_compilation_passes", which depends on this target.
cc_library(
    name = "compilation_passes",
    srcs = [
        "build_xla_ops_pass.cc",
        "clone_constants_for_better_clustering.cc",
        "cluster_scoping_pass.cc",
        "deadness_analysis.cc",
        "deadness_analysis_internal.h",
        "encapsulate_subgraphs_pass.cc",
        "encapsulate_xla_computations_pass.cc",
        "extract_outside_compilation_pass.cc",
        "force_xla_constants_on_host_pass.cc",
        "increase_dynamism_for_auto_jit_pass.cc",
        "introduce_floating_point_jitter_pass.cc",
        "mark_for_compilation_pass.cc",
        # NOTE(review): a test helper compiled into a non-testonly library;
        # confirm this is intentional.
        "mark_for_compilation_pass_test_helper.cc",
        "partially_decluster_pass.cc",
        "report_clustering_info_pass.cc",
    ],
    hdrs = [
        "build_xla_ops_pass.h",
        "clone_constants_for_better_clustering.h",
        "cluster_scoping_pass.h",
        "deadness_analysis.h",
        "encapsulate_subgraphs_pass.h",
        "encapsulate_xla_computations_pass.h",
        "extract_outside_compilation_pass.h",
        "force_xla_constants_on_host_pass.h",
        "increase_dynamism_for_auto_jit_pass.h",
        "introduce_floating_point_jitter_pass.h",
        "mark_for_compilation_pass.h",
        "mark_for_compilation_pass_test_helper.h",
        "partially_decluster_pass.h",
        "report_clustering_info_pass.h",
    ],
    visibility = [
        ":internal",
        "//tensorflow/core/tfrt/utils:__pkg__",
        "//third_party/cloud_tpu/inference_converter:__pkg__",
    ],
    deps = [
        # Written with an explicit ":" like every other same-package label in
        # this file (previously the bare "compilability_check_util").
        ":compilability_check_util",
        ":common",
        ":device_util",
        ":encapsulate_util",
        ":flags",
        ":resource_operation_safety_analysis",
        ":shape_inference_helpers",
        ":xla_activity_listener",
        ":xla_cluster_util",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:functional_ops",
        "//tensorflow/cc:ops",
        "//tensorflow/cc:scope",
        "//tensorflow/cc:scope_internal",
        "//tensorflow/compiler/jit/ops:xla_ops",
        "//tensorflow/compiler/tf2xla:resource_operation_table",
        "//tensorflow/compiler/tf2xla:side_effect_util",
        "//tensorflow/compiler/tf2xla:tf2xla_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla:xla_op_registry",
        "//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
        "//tensorflow/compiler/tf2xla/cc:xla_ops",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:union_find",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/service/graphcycles",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/framework:bounds_check",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
    ],
)
873
# Library for xla_cluster_util.cc/.h; uses graph cycle detection and the XLA
# activity proto. Package-default visibility.
cc_library(
    name = "xla_cluster_util",
    srcs = ["xla_cluster_util.cc"],
    hdrs = ["xla_cluster_util.h"],
    deps = [
        ":flags",
        ":xla_activity_proto_cc",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla/service/graphcycles",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/framework:bounds_check",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
    ],
)
900
# C++ library built from device_util.cc/.h. It builds on tf2xla's compiler and
# op registry. It is exercised by ":device_util_test".
# NOTE(review): presumably device-name/device-type helpers for the JIT.
# Confirm against device_util.h.
cc_library(
    name = "device_util",
    srcs = ["device_util.cc"],
    hdrs = ["device_util.h"],
    deps = [
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla:xla_op_registry",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/core:framework",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)
918
# Unit test for ":device_util".
# No `size` is set, so Bazel's default test size applies.
# main() presumably comes from //tensorflow/core:test_main. Confirm.
tf_cc_test(
    name = "device_util_test",
    srcs = ["device_util_test.cc"],
    deps = [
        ":device_util",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)
929
# Small unit test for the deadness analysis.
# deadness_analysis_internal.h is compiled directly into the test rather than
# exported from a library.
# NOTE(review): presumably because it exposes internals only this test needs.
# Likewise, the analysis implementation is presumably linked via
# ":compilation_passes". Confirm both.
tf_cc_test(
    name = "deadness_analysis_test",
    size = "small",
    srcs = [
        "deadness_analysis_internal.h",
        "deadness_analysis_test.cc",
    ],
    deps = [
        ":common",
        ":compilation_passes",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:cc_ops_internal",
        "//tensorflow/cc:function_ops",
        "//tensorflow/cc:ops",
        "//tensorflow/cc:sendrecv_ops",
        "//tensorflow/compiler/jit/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
        "@com_google_absl//absl/container:flat_hash_map",
    ],
)
959
# Test-only library wrapping compilation_passes_test_main.cc.
# ":compilation_passes_test" links this instead of //tensorflow/core:test_main.
# NOTE(review): the ":flags" dep suggests this main() adjusts JIT flags before
# running the tests. Confirm in the .cc.
# testonly = True: only testonly targets may depend on it.
# Public visibility lets tests in other packages reuse it.
cc_library(
    name = "compilation_passes_test_main",
    testonly = True,
    srcs = ["compilation_passes_test_main.cc"],
    visibility = ["//visibility:public"],
    deps = [
        ":flags",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "@com_google_absl//absl/strings",
    ],
)
972
# Single small test binary covering many JIT graph passes, with one *_test.cc
# per pass listed in srcs.
# main() comes from ":compilation_passes_test_main", which is why
# //tensorflow/core:test_main is not a dep.
tf_cc_test(
    name = "compilation_passes_test",
    size = "small",
    srcs = [
        "build_xla_ops_pass_test.cc",
        "clone_constants_for_better_clustering_test.cc",
        "cluster_scoping_pass_test.cc",
        "encapsulate_subgraphs_pass_test.cc",
        "encapsulate_xla_computations_pass_test.cc",
        "extract_outside_compilation_pass_test.cc",
        "force_xla_constants_on_host_pass_test.cc",
        "increase_dynamism_for_auto_jit_pass_test.cc",
        "introduce_floating_point_jitter_pass_internal.h",
        "introduce_floating_point_jitter_pass_test.cc",
        "mark_for_compilation_pass_test.cc",
        "partially_decluster_pass_test.cc",
        "rearrange_function_argument_pass_test.cc",
    ],
    # tf_cuda_tests_tags() comes from build_config_root.bzl.
    # NOTE(review): presumably it adds GPU/CUDA-related tags. Confirm there.
    tags = [
        # TODO(b/141643254) Re-enable msan after fixing
        # use-of-uninitialized-value error.
        "nomsan",
    ] + tf_cuda_tests_tags(),
    deps = [
        ":common",
        ":compilability_check_util",
        ":compilation_passes",
        ":compilation_passes_test_main",
        ":encapsulate_util",
        ":flags",
        ":node_matchers",
        ":test_util",
        ":xla_cluster_util",
        ":xla_cpu_device",
        ":xla_gpu_device",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:cc_ops_internal",
        "//tensorflow/cc:function_ops",
        "//tensorflow/cc:functional_ops",
        "//tensorflow/cc:ops",
        "//tensorflow/cc:resource_variable_ops",
        "//tensorflow/cc:scope",
        "//tensorflow/cc:sendrecv_ops",
        "//tensorflow/compiler/jit/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla:rearrange_function_argument",
        "//tensorflow/compiler/tf2xla:side_effect_util",
        "//tensorflow/compiler/tf2xla:test_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
        "//tensorflow/compiler/tf2xla/cc:xla_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/core:all_kernels",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:session_options",
        "//tensorflow/core:test",
        "//tensorflow/core:testlib",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)
1041
# Small unit test for ":xla_cluster_util".
# main() presumably comes from //tensorflow/core:test_main. Confirm.
tf_cc_test(
    name = "xla_cluster_util_test",
    size = "small",
    srcs = [
        "xla_cluster_util_test.cc",
    ],
    deps = [
        ":common",
        ":xla_cluster_util",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:cc_ops_internal",
        "//tensorflow/cc:function_ops",
        "//tensorflow/cc:functional_ops",
        "//tensorflow/cc:ops",
        "//tensorflow/compiler/jit/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/strings",
    ],
)
1072
# Test-only library built from node_matchers.cc/.h and used by graph-pass
# tests such as ":compilation_passes_test".
# NOTE(review): given the name and the //tensorflow/compiler/xla:test dep,
# presumably gMock-style matchers over TF graph nodes. Confirm in the header.
# testonly = True: only testonly targets may depend on it.
cc_library(
    name = "node_matchers",
    testonly = True,
    srcs = ["node_matchers.cc"],
    hdrs = ["node_matchers.h"],
    deps = [
        "//tensorflow/cc:ops",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
    ],
)
1090
# Unit test for ":node_matchers".
# //tensorflow/core:ops is linked, presumably so core op definitions are
# registered when the test builds graphs. Confirm.
tf_cc_test(
    name = "node_matchers_test",
    srcs = ["node_matchers_test.cc"],
    deps = [
        ":node_matchers",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:cc_ops_internal",
        "//tensorflow/cc:ops",
        "//tensorflow/core:ops",
        "//tensorflow/core:test_main",
    ],
)
1103
# C++ library built from compilability_check_util.cc/.h.
# NOTE(review): presumably it decides whether TF nodes can be compiled by XLA,
# given its deps on the XLA op registry, the resource-operation table and
# ":resource_operation_safety_analysis". Confirm in the header.
# The explicit visibility overrides the package default_visibility and
# exposes this target to the ":friends" package_group.
cc_library(
    name = "compilability_check_util",
    srcs = ["compilability_check_util.cc"],
    hdrs = ["compilability_check_util.h"],
    visibility = [
        ":friends",
    ],
    deps = [
        ":common",
        ":device_util",
        ":flags",
        ":resource_operation_safety_analysis",
        ":xla_activity_listener",
        ":xla_activity_proto_cc",
        ":xla_cluster_util",
        "//tensorflow/compiler/tf2xla:resource_operation_table",
        "//tensorflow/compiler/tf2xla:tf2xla_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla:xla_op_registry",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:union_find",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla/service/graphcycles",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
    ],
)
1139
# Unit test for ":compilability_check_util".
# ":xla_cpu_device" and ":xla_cpu_jit" are linked, presumably so the XLA CPU
# device and its JIT kernels are registered while the checks run. Confirm.
tf_cc_test(
    name = "compilability_check_util_test",
    srcs = ["compilability_check_util_test.cc"],
    deps = [
        ":compilability_check_util",
        ":xla_cpu_device",
        ":xla_cpu_jit",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:function_ops",
        "//tensorflow/cc:functional_ops",
        "//tensorflow/cc:ops",
        "//tensorflow/cc:scope",
        "//tensorflow/compiler/tf2xla:test_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
        "//tensorflow/compiler/tf2xla/cc:xla_ops",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:ops",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "@com_google_absl//absl/memory",
    ],
)
1165
# Test for ":xla_activity_listener". It runs real graphs through a
# DirectSession on the XLA CPU device/JIT, with selected core kernels linked.
# NOTE(review): unlike most tests here it does not depend on
# //tensorflow/core:test_main. Presumably xla_activity_listener_test.cc
# defines its own main() (the ":flags" dep suggests it sets JIT flags first).
# Confirm before adding a test_main dep.
tf_cc_test(
    name = "xla_activity_listener_test",
    srcs = ["xla_activity_listener_test.cc"],
    deps = [
        ":flags",
        ":xla_activity_listener",
        ":xla_cpu_device",
        ":xla_cpu_jit",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:ops",
        "//tensorflow/core:all_kernels",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:ops",
        "//tensorflow/core:test",
        "//tensorflow/core/common_runtime:direct_session_internal",
        "//tensorflow/core/kernels:cwise_op",
        "//tensorflow/core/kernels:matmul_op",
        "//tensorflow/core/kernels:partitioned_function_ops",
    ],
)
1187
# Python library exposing the custom-op kernels from
# //tensorflow/compiler/jit/ops:xla_ops, together with their generated
# wrappers (xla_ops_wrapper_py) and gradients (xla_ops_grad).
# Visible to the ":friends" package_group.
tf_custom_op_py_library(
    name = "xla_ops_py",
    kernels = ["//tensorflow/compiler/jit/ops:xla_ops"],
    visibility = [
        ":friends",
    ],
    deps = [
        "//tensorflow/compiler/jit/ops:xla_ops_grad",
        "//tensorflow/compiler/jit/ops:xla_ops_wrapper_py",
    ],
)
1199
# C++ library built from xla_activity_listener.cc/.h and built on the
# xla_activity proto messages. Other JIT targets such as
# ":compilability_check_util" depend on it.
# Publicly visible, so targets outside TensorFlow may depend on it.
# The absl/synchronization dep suggests internal locking.
# NOTE(review): the exact thread-safety contract is not visible here. Confirm
# in the header.
cc_library(
    name = "xla_activity_listener",
    srcs = ["xla_activity_listener.cc"],
    hdrs = ["xla_activity_listener.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":xla_activity_proto_cc",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/synchronization",
    ],
)
1211
# Proto library for xla_activity.proto.
# NOTE(review): the ":xla_activity_proto_cc" label used elsewhere in this
# package is presumably the C++ target this macro generates. The exact set of
# generated targets is defined in build_config.bzl. Confirm there.
tf_proto_library(
    name = "xla_activity_proto",
    srcs = ["xla_activity.proto"],
    cc_api_version = 2,
    protodeps = tf_additional_all_protos(),
)
1218
# C++ library built from xla_activity_logging_listener.cc, with no public
# headers. It builds on ":xla_activity_listener".
# alwayslink = 1 forces all of its object files into any binary that depends
# on it, even if no symbol from them is referenced.
# NOTE(review): this is presumably because the .cc registers a listener from
# a static initializer. Confirm in the source.
cc_library(
    name = "xla_activity_logging_listener",
    srcs = ["xla_activity_logging_listener.cc"],
    deps = [
        ":xla_activity_listener",
        ":xla_activity_proto_cc",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/memory",
    ],
    alwayslink = 1,
)
1230