• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# buildifier: disable=same-origin-load
2load(
3    "//tensorflow/stream_executor:build_defs.bzl",
4    "if_gpu_is_configured",
5)
6
7# buildifier: disable=same-origin-load
8load(
9    "//tensorflow:tensorflow.bzl",
10    "tf_cc_binary",
11)
12load(
13    "//tensorflow/tsl/platform/default:cuda_build_defs.bzl",
14    "if_cuda_is_configured",
15)
16load(
17    "@local_config_rocm//rocm:build_defs.bzl",
18    "if_rocm_is_configured",
19)
20load(
21    "//tensorflow/core/platform:build_config.bzl",
22    "if_llvm_aarch64_available",
23    "if_llvm_system_z_available",
24    "tf_proto_library",
25)
26
# Package-wide defaults: every target below is visible to the `friends`
# package group unless it sets its own `visibility` attribute.
package(
    default_visibility = [":friends"],
    licenses = ["notice"],
)
31
# Packages allowed to depend on this package's targets by default; it is
# used as the package-level default_visibility.
package_group(
    name = "friends",
    packages = [
        # The Edge TPU compiler reuses some compiler passes from kernel_gen.
        "//platforms/darwinn/compiler/...",
        "//tensorflow/compiler/...",
        "//tensorflow/core/kernels/mlir_generated/...",
    ],
)
41
# Library built from kernel_creator.{h,cc}. Its deps include MLIR GPU, NVVM
# and ROCDL conversion libraries, so it presumably drives the lowering of TF
# ops to GPU kernels (NOTE(review): confirm against kernel_creator.h).
cc_library(
    name = "kernel_creator",
    srcs = ["kernel_creator.cc"],
    hdrs = ["kernel_creator.h"],
    # Tell the sources which GPU backend, if any, this build is configured for.
    copts = if_cuda_is_configured([
        "-DGOOGLE_CUDA=1",
    ]) + if_rocm_is_configured([
        "-DTENSORFLOW_USE_ROCM=1",
    ]),
    deps = [
        # TensorFlow, kernel_gen and XLA MLIR components.
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
        "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:bufferize",
        "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:gpu_passes",
        "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:passes",
        "//tensorflow/compiler/mlir/xla:xla_legalize_tf_no_fallback",
        "//tensorflow/compiler/xla/mlir_hlo",
        "//tensorflow/compiler/xla/mlir_hlo:all_passes",
        "//tensorflow/compiler/xla/mlir_hlo:chlo_legalize_to_hlo_pass",
        "//tensorflow/compiler/xla/mlir_hlo:gpu_passes",
        "//tensorflow/compiler/xla/mlir_hlo:shape_simplification",
        "//tensorflow/core:lib",
        # Upstream LLVM / MLIR libraries.
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AffineToStandard",
        "@llvm-project//mlir:ArithmeticDialect",
        "@llvm-project//mlir:ArithmeticTransforms",
        "@llvm-project//mlir:BufferizationTransforms",
        "@llvm-project//mlir:ComplexToStandard",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:FuncToLLVM",
        "@llvm-project//mlir:GPUDialect",
        "@llvm-project//mlir:GPUToGPURuntimeTransforms",
        "@llvm-project//mlir:GPUToNVVMTransforms",
        "@llvm-project//mlir:GPUTransforms",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LLVMIRTransforms",
        "@llvm-project//mlir:LLVMToLLVMIRTranslation",
        "@llvm-project//mlir:LinalgTransforms",
        "@llvm-project//mlir:MemRefTransforms",
        "@llvm-project//mlir:NVVMToLLVMIRTranslation",
        "@llvm-project//mlir:Parser",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:ROCDLToLLVMIRTranslation",
        "@llvm-project//mlir:ReconcileUnrealizedCasts",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:SCFToControlFlow",
        "@llvm-project//mlir:SCFToGPU",
        "@llvm-project//mlir:SCFTransforms",
        "@llvm-project//mlir:ShapeToStandard",
        "@llvm-project//mlir:Transforms",
        "@llvm-project//mlir:VectorToLLVM",
    ],
)
91
# Command-line tool built from tf_to_kernel.cc on top of :kernel_creator.
# Visibility is restricted to its tests and to the mlir_generated kernels
# package.
tf_cc_binary(
    name = "tf_to_kernel",
    srcs = ["tf_to_kernel.cc"],
    visibility = [
        "//tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel:__pkg__",
        "//tensorflow/core/kernels/mlir_generated:__pkg__",
    ],
    # Sorted in buildifier (byte-wise) order, like the other lists in this file.
    # Entries tagged `fixdeps: keep` must not be pruned by dependency tooling;
    # they are presumably needed only for LLVM target registration rather than
    # direct header use (NOTE(review): confirm in tf_to_kernel.cc).
    deps = [
        ":kernel_creator",
        "//tensorflow/compiler/mlir:init_mlir",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/core:lib",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:ARMCodeGen",  # fixdeps: keep
        "@llvm-project//llvm:Analysis",
        "@llvm-project//llvm:CodeGen",
        "@llvm-project//llvm:Core",
        "@llvm-project//llvm:MC",
        "@llvm-project//llvm:PowerPCCodeGen",  # fixdeps: keep
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:Target",
        "@llvm-project//llvm:X86CodeGen",  # fixdeps: keep
        "@llvm-project//llvm:X86Disassembler",  # fixdeps: keep
        "@llvm-project//mlir:ExecutionEngineUtils",
        "@llvm-project//mlir:LLVMToLLVMIRTranslation",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:ToLLVMIRTranslation",
    ] + if_llvm_system_z_available([
        # Only linked when the SystemZ LLVM backend is available.
        "@llvm-project//llvm:SystemZCodeGen",  # fixdeps: keep
    ]) + if_llvm_aarch64_available([
        # Only linked when the AArch64 LLVM backend is available.
        "@llvm-project//llvm:AArch64CodeGen",  # fixdeps: keep
    ]),
)
126
# Tool built from tools/kernel-gen-opt/kernel-gen-opt.cc. It links MlirOptLib
# and AllPassesAndDialects, so it is presumably an mlir-opt-style driver that
# also registers the kernel_gen and mlir_hlo passes (NOTE(review): confirm).
# Only the kernel_gen tests may depend on it.
tf_cc_binary(
    name = "kernel-gen-opt",
    srcs = ["tools/kernel-gen-opt/kernel-gen-opt.cc"],
    visibility = ["//tensorflow/compiler/mlir/tools/kernel_gen/tests:__subpackages__"],
    deps = [
        # TensorFlow, kernel_gen and mlir_hlo dialects and passes.
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
        "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:gpu_passes",
        "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:passes",
        "//tensorflow/compiler/xla/mlir_hlo:all_passes",
        "//tensorflow/compiler/xla/mlir_hlo:gml_st",
        "//tensorflow/compiler/xla/mlir_hlo:hlo_dialect_registration",
        # Upstream LLVM / MLIR libraries.
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:MlirOptLib",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",
    ],
)
147
# Export the C interface header so it can be referenced as a file label from
# other packages (exports_files without `visibility` makes it public).
exports_files([
    "tf_framework_c_interface.h",
])
149
# Library built from tf_framework_c_interface.{h,cc}. It combines the kernel
# creator, the GPU runtime wrappers, the JIT cache and the compile-cache proto
# with the MLIR execution engine.
cc_library(
    name = "tf_framework_c_interface",
    srcs = ["tf_framework_c_interface.cc"],
    hdrs = ["tf_framework_c_interface.h"],
    # Tell the sources which GPU backend, if any, this build is configured for.
    copts = if_cuda_is_configured([
        "-DGOOGLE_CUDA=1",
    ]) + if_rocm_is_configured([
        "-DTENSORFLOW_USE_ROCM=1",
    ]),
    deps = [
        # Targets defined in this package.
        ":compile_cache_item_proto_cc",
        ":kernel_creator",
        ":tf_gpu_runtime_wrappers",
        ":tf_jit_cache",
        # Other TensorFlow components.
        "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core/platform:strcat",
        "//tensorflow/stream_executor:stream_header",
        # Upstream LLVM / MLIR libraries.
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ExecutionEngine",
        "@llvm-project//mlir:ExecutionEngineUtils",
        "@llvm-project//mlir:Parser",
        "@llvm-project//mlir:mlir_runner_utils",
    ],
)
172
# Library built from tf_jit_cache.{h,cc}. Its deps (absl flat_hash_map plus
# the MLIR ExecutionEngine) suggest it caches JIT-compiled executables keyed
# by some lookup value (NOTE(review): confirm in tf_jit_cache.h).
cc_library(
    name = "tf_jit_cache",
    srcs = ["tf_jit_cache.cc"],
    hdrs = ["tf_jit_cache.h"],
    deps = [
        "//tensorflow/core:framework",
        "@com_google_absl//absl/container:flat_hash_map",
        "@llvm-project//mlir:ExecutionEngine",
    ],
)
183
# GPU runtime wrapper library built from tf_gpu_runtime_wrappers.{h,cc}.
# Sources and headers are only included when a GPU toolchain (CUDA or ROCm)
# is configured; otherwise this is an empty library that only forwards deps.
cc_library(
    name = "tf_gpu_runtime_wrappers",
    srcs = if_gpu_is_configured([
        "tf_gpu_runtime_wrappers.cc",
    ]),
    hdrs = if_gpu_is_configured([
        "tf_gpu_runtime_wrappers.h",
    ]),
    # Same backend defines as the other GPU-aware targets in this package.
    copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_rocm_is_configured(["-DTENSORFLOW_USE_ROCM=1"]),
    # Each list is sorted in buildifier order: `//` labels before `@repo` labels.
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core/platform:logging",
        "//tensorflow/core/platform:mutex",
        "//tensorflow/stream_executor:stream_header",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
        "@llvm-project//mlir:mlir_runner_utils",
    ] + if_cuda_is_configured([
        # CUDA headers and the CUDA stream executor, CUDA builds only.
        "//tensorflow/tsl/platform/default/build_config:stream_executor_cuda",
        "@local_config_cuda//cuda:cuda_headers",
    ]) + if_rocm_is_configured([
        # ROCm headers and the ROCm stream executor, ROCm builds only.
        "//tensorflow/tsl/platform/default/build_config:stream_executor_rocm",
        "@local_config_rocm//rocm:rocm_headers",
    ]),
)
212
# Proto library for compile_cache_item.proto. The generated C++ target
# (referenced elsewhere in this package as :compile_cache_item_proto_cc) is
# presumably produced by this macro.
tf_proto_library(
    name = "compile_cache_item_proto",
    srcs = ["compile_cache_item.proto"],
    cc_api_version = 2,
)
218