• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Description:
#   ROCm-platform specific StreamExecutor support code.
3
4load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
5load(
6    "//tensorflow/stream_executor:build_defs.bzl",
7    "stream_executor_friends",
8)
9load("//tensorflow:tensorflow.bzl", "tf_copts")
10load("//tensorflow:tensorflow.bzl", "filegroup")
11load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
12load("//tensorflow/core/platform:build_config_root.bzl", "if_static")
13
# Package-wide defaults. Targets that do not set their own `visibility` are
# visible only to the ":friends" package group declared in this file.
package(
    licenses = ["notice"],  # Apache 2.0
    default_visibility = [":friends"],
)
18
# Package group used as this package's default visibility. The member list is
# whatever stream_executor_friends() (loaded from
# //tensorflow/stream_executor:build_defs.bzl) returns.
package_group(
    name = "friends",
    packages = stream_executor_friends(),
)
23
# Gathers every C++ source (*.cc) and header (*.h) under this package,
# recursively, for use by the dependency check.
# NOTE(review): the files are attached through `data` rather than `srcs` --
# confirm that this is what the dependency check consumes.
filegroup(
    name = "c_srcs",
    data = glob(["**/*.cc", "**/*.h"]),
)
32
# ROCm diagnostics library (rocm_diagnostics.h / rocm_diagnostics.cc).
# Sources, headers and deps are all passed through if_rocm_is_configured().
cc_library(
    name = "rocm_diagnostics",
    hdrs = if_rocm_is_configured(["rocm_diagnostics.h"]),
    srcs = if_rocm_is_configured(["rocm_diagnostics.cc"]),
    deps = if_rocm_is_configured([
        # Abseil.
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        # StreamExecutor.
        "//tensorflow/stream_executor/gpu:gpu_diagnostics_header",
        "//tensorflow/stream_executor/lib",
        "//tensorflow/stream_executor/platform",
    ]),
)
46
# ROCm driver library: compiles rocm_driver.cc and publishes
# rocm_driver_wrapper.h. All attributes go through if_rocm_is_configured().
cc_library(
    name = "rocm_driver",
    hdrs = if_rocm_is_configured(["rocm_driver_wrapper.h"]),
    srcs = if_rocm_is_configured(["rocm_driver.cc"]),
    deps = if_rocm_is_configured([
        # Same package.
        ":rocm_diagnostics",
        # Abseil.
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/strings",
        # StreamExecutor.
        "//tensorflow/stream_executor:device_options",
        "//tensorflow/stream_executor/gpu:gpu_driver_header",
        "//tensorflow/stream_executor/lib",
        "//tensorflow/stream_executor/platform",
        "//tensorflow/stream_executor/platform:dso_loader",
        # ROCm SDK headers.
        "@local_config_rocm//rocm:rocm_headers",
    ]),
)
64
# Header-only target exposing rocm_activation.h (it has no source files).
cc_library(
    name = "rocm_activation",
    hdrs = if_rocm_is_configured(["rocm_activation.h"]),
    deps = if_rocm_is_configured([
        # Same package.
        ":rocm_driver",
        # ROCm SDK headers.
        "@local_config_rocm//rocm:rocm_headers",
        # StreamExecutor.
        "//tensorflow/stream_executor",
        "//tensorflow/stream_executor:stream_executor_internal",
        "//tensorflow/stream_executor/gpu:gpu_activation",
        "//tensorflow/stream_executor/platform",
    ]),
)
78
# ROCm event implementation: compiles rocm_event.cc and publishes no headers of
# its own.
cc_library(
    name = "rocm_event",
    srcs = if_rocm_is_configured(["rocm_event.cc"]),
    deps = if_rocm_is_configured([
        # Same package.
        ":rocm_driver",
        # StreamExecutor.
        "//tensorflow/stream_executor:stream_executor_headers",
        "//tensorflow/stream_executor/gpu:gpu_event_header",
        "//tensorflow/stream_executor/gpu:gpu_executor_header",
        "//tensorflow/stream_executor/gpu:gpu_stream_header",
        "//tensorflow/stream_executor/lib",
    ]),
)
92
# ROCm GPU executor: compiles rocm_gpu_executor.cc and publishes no headers of
# its own.
cc_library(
    name = "rocm_gpu_executor",
    srcs = if_rocm_is_configured(["rocm_gpu_executor.cc"]),
    deps = if_rocm_is_configured([
        # Same package.
        ":rocm_diagnostics",
        ":rocm_driver",
        ":rocm_event",
        ":rocm_kernel",
        ":rocm_platform_id",
        # Abseil.
        "@com_google_absl//absl/strings",
        # StreamExecutor.
        "//tensorflow/stream_executor:event",
        "//tensorflow/stream_executor:plugin_registry",
        "//tensorflow/stream_executor:stream_executor_internal",
        "//tensorflow/stream_executor:stream_executor_pimpl_header",
        "//tensorflow/stream_executor:timer",
        "//tensorflow/stream_executor/gpu:gpu_activation_header",
        "//tensorflow/stream_executor/gpu:gpu_event",
        "//tensorflow/stream_executor/gpu:gpu_kernel_header",
        "//tensorflow/stream_executor/gpu:gpu_stream",
        "//tensorflow/stream_executor/gpu:gpu_timer",
        "//tensorflow/stream_executor/lib",
        "//tensorflow/stream_executor/platform",
        "//tensorflow/stream_executor/platform:dso_loader",
    ]),
    # Keep this library's object files in any binary that depends on it, even
    # when none of its symbols are referenced directly.
    alwayslink = True,
)
120
# ROCm kernel implementation (rocm_kernel.cc). Publicly visible; publishes no
# headers of its own.
cc_library(
    name = "rocm_kernel",
    visibility = ["//visibility:public"],
    srcs = if_rocm_is_configured(["rocm_kernel.cc"]),
    deps = if_rocm_is_configured(
        ["//tensorflow/stream_executor/gpu:gpu_kernel_header"],
    ),
    # Keep the object file in the link even if nothing references its symbols.
    alwayslink = True,
)
131
# ROCm platform library (rocm_platform.h / rocm_platform.cc). Publicly visible.
cc_library(
    name = "rocm_platform",
    visibility = ["//visibility:public"],
    hdrs = if_rocm_is_configured(["rocm_platform.h"]),
    srcs = if_rocm_is_configured(["rocm_platform.cc"]),
    deps = if_rocm_is_configured([
        # Same package.
        ":rocm_driver",
        ":rocm_gpu_executor",
        ":rocm_platform_id",
        # Abseil.
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/memory",
        # TensorFlow core and StreamExecutor.
        "//tensorflow/core:lib",
        "//tensorflow/stream_executor",  # buildcleaner: keep
        "//tensorflow/stream_executor:executor_cache",
        "//tensorflow/stream_executor:multi_platform_manager",
        "//tensorflow/stream_executor:stream_executor_pimpl_header",
        "//tensorflow/stream_executor/lib",
        "//tensorflow/stream_executor/platform",
    ]),
    # Registers itself with the MultiPlatformManager, so the object file must
    # be linked even when no symbol in it is referenced.
    alwayslink = True,
)
153
# ROCm platform identifier (rocm_platform_id.h / rocm_platform_id.cc).
# Unlike most targets in this file, its attributes are not wrapped in
# if_rocm_is_configured(), so it is built unconditionally.
cc_library(
    name = "rocm_platform_id",
    hdrs = ["rocm_platform_id.h"],
    srcs = ["rocm_platform_id.cc"],
    deps = [
        "//tensorflow/stream_executor:platform",
    ],
)
160
# Indirection target: depends on rocBLAS only through if_static().
# NOTE(review): presumably this adds rocBLAS for static builds only -- confirm
# against if_static() in //tensorflow/core/platform:build_config_root.bzl.
cc_library(
    name = "rocblas_if_static",
    deps = if_static(["@local_config_rocm//rocm:rocblas"]),
)
167
# BLAS plugin for ROCm (rocm_blas.h / rocm_blas.cc). Publicly visible.
cc_library(
    name = "rocblas_plugin",
    visibility = ["//visibility:public"],
    hdrs = if_rocm_is_configured(["rocm_blas.h"]),
    srcs = if_rocm_is_configured(["rocm_blas.cc"]),
    deps = if_rocm_is_configured([
        # Same package.
        ":rocblas_if_static",
        ":rocm_gpu_executor",
        ":rocm_platform_id",
        # Eigen and TensorFlow core.
        "//third_party/eigen3",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        # StreamExecutor.
        "//tensorflow/stream_executor",
        "//tensorflow/stream_executor:event",
        "//tensorflow/stream_executor:host_or_device_scalar",
        "//tensorflow/stream_executor:plugin_registry",
        "//tensorflow/stream_executor:scratch_allocator",
        "//tensorflow/stream_executor:timer",
        "//tensorflow/stream_executor/gpu:gpu_activation",
        "//tensorflow/stream_executor/gpu:gpu_helpers_header",
        "//tensorflow/stream_executor/gpu:gpu_stream_header",
        "//tensorflow/stream_executor/gpu:gpu_timer_header",
        "//tensorflow/stream_executor/lib",
        "//tensorflow/stream_executor/platform",
        "//tensorflow/stream_executor/platform:dso_loader",
        # Abseil and ROCm SDK headers.
        "@com_google_absl//absl/strings",
        "@local_config_rocm//rocm:rocm_headers",
    ]),
    # Keep the plugin's object file in the link even if nothing references its
    # symbols directly.
    alwayslink = True,
)
198
# Indirection target: depends on rocFFT only through if_static().
# NOTE(review): presumably this adds rocFFT for static builds only -- confirm
# against if_static() in //tensorflow/core/platform:build_config_root.bzl.
cc_library(
    name = "rocfft_if_static",
    deps = if_static(["@local_config_rocm//rocm:rocfft"]),
)
205
# FFT plugin for ROCm (rocm_fft.h / rocm_fft.cc). Publicly visible.
cc_library(
    name = "rocfft_plugin",
    visibility = ["//visibility:public"],
    hdrs = if_rocm_is_configured(["rocm_fft.h"]),
    srcs = if_rocm_is_configured(["rocm_fft.cc"]),
    deps = if_rocm_is_configured([
        # Same package.
        ":rocfft_if_static",
        ":rocm_platform_id",
        # StreamExecutor.
        "//tensorflow/stream_executor:event",
        "//tensorflow/stream_executor:fft",
        "//tensorflow/stream_executor:plugin_registry",
        "//tensorflow/stream_executor:scratch_allocator",
        "//tensorflow/stream_executor/gpu:gpu_activation",
        "//tensorflow/stream_executor/gpu:gpu_helpers_header",
        "//tensorflow/stream_executor/gpu:gpu_executor_header",
        "//tensorflow/stream_executor/gpu:gpu_stream_header",
        "//tensorflow/stream_executor/gpu:gpu_kernel_header",
        "//tensorflow/stream_executor/lib",
        "//tensorflow/stream_executor/platform",
        "//tensorflow/stream_executor/platform:dso_loader",
        # ROCm SDK headers.
        "@local_config_rocm//rocm:rocm_headers",
    ]),
    # Keep the plugin's object file in the link even if nothing references its
    # symbols directly.
    alwayslink = True,
)
230
# Indirection target: depends on MIOpen only through if_static().
# NOTE(review): presumably this adds MIOpen for static builds only -- confirm
# against if_static() in //tensorflow/core/platform:build_config_root.bzl.
cc_library(
    name = "miopen_if_static",
    deps = if_static(["@local_config_rocm//rocm:miopen"]),
)
237
# DNN plugin for ROCm (rocm_dnn.h / rocm_dnn.cc). Publicly visible.
cc_library(
    name = "miopen_plugin",
    visibility = ["//visibility:public"],
    hdrs = if_rocm_is_configured(["rocm_dnn.h"]),
    srcs = if_rocm_is_configured(["rocm_dnn.cc"]),
    copts = [
        # Raise the template instantiation depth: STREAM_EXECUTOR_CUDNN_WRAP
        # would fail on Clang with the default setting of template depth 256.
        # NOTE(review): the macro named here is the cuDNN one; confirm which
        # macro in rocm_dnn.cc actually needs the deeper limit.
        "-ftemplate-depth-512",
    ],
    deps = if_rocm_is_configured([
        # Same package.
        ":miopen_if_static",
        ":rocm_diagnostics",
        ":rocm_driver",
        ":rocm_gpu_executor",
        ":rocm_platform_id",
        # Eigen and TensorFlow core.
        "//third_party/eigen3",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        # StreamExecutor.
        "//tensorflow/stream_executor:dnn",
        "//tensorflow/stream_executor:event",
        "//tensorflow/stream_executor:plugin_registry",
        "//tensorflow/stream_executor:scratch_allocator",
        "//tensorflow/stream_executor:stream_executor_pimpl_header",
        "//tensorflow/stream_executor:temporary_device_memory",
        "//tensorflow/stream_executor/gpu:gpu_activation_header",
        "//tensorflow/stream_executor/gpu:gpu_stream_header",
        "//tensorflow/stream_executor/gpu:gpu_timer_header",
        "//tensorflow/stream_executor/lib",
        "//tensorflow/stream_executor/platform",
        "//tensorflow/stream_executor/platform:dso_loader",
        # Abseil and ROCm SDK headers.
        "@com_google_absl//absl/strings",
        "@local_config_rocm//rocm:rocm_headers",
    ]),
    # Keep the plugin's object file in the link even if nothing references its
    # symbols directly.
    alwayslink = True,
)
274
# Indirection target: depends on hipRAND only through if_static().
# NOTE(review): presumably this adds hipRAND for static builds only -- confirm
# against if_static() in //tensorflow/core/platform:build_config_root.bzl.
cc_library(
    name = "hiprand_if_static",
    deps = if_static(["@local_config_rocm//rocm:hiprand"]),
)
281
# RNG plugin for ROCm: compiles rocm_rng.cc and publishes no headers.
# (The former `hdrs = if_rocm_is_configured([])` wrapped an empty list and so
# contributed nothing; the attribute is simply omitted.)
cc_library(
    name = "rocrand_plugin",
    srcs = if_rocm_is_configured(["rocm_rng.cc"]),
    deps = if_rocm_is_configured([
        # Same package.
        ":hiprand_if_static",
        ":rocm_gpu_executor",
        ":rocm_platform_id",
        # ROCm SDK headers.
        "@local_config_rocm//rocm:rocm_headers",
        # StreamExecutor.
        "//tensorflow/stream_executor:event",
        "//tensorflow/stream_executor:plugin_registry",
        "//tensorflow/stream_executor:rng",
        "//tensorflow/stream_executor/gpu:gpu_activation_header",
        "//tensorflow/stream_executor/gpu:gpu_helpers_header",
        "//tensorflow/stream_executor/gpu:gpu_executor_header",
        "//tensorflow/stream_executor/gpu:gpu_rng_header",
        "//tensorflow/stream_executor/gpu:gpu_stream_header",
        "//tensorflow/stream_executor/lib",
        "//tensorflow/stream_executor/platform",
        "//tensorflow/stream_executor/platform:dso_loader",
    ]),
    # Keep the plugin's object file in the link even if nothing references its
    # symbols directly.
    alwayslink = True,
)
305
# Indirection target: depends on hipSPARSE only through if_static().
# NOTE(review): presumably this adds hipSPARSE for static builds only --
# confirm against if_static() in
# //tensorflow/core/platform:build_config_root.bzl.
cc_library(
    name = "hipsparse_if_static",
    deps = if_static(["@local_config_rocm//rocm:hipsparse"]),
)
312
# Header-only wrapper around hipSPARSE, published as hipsparse_wrapper.h.
# The header is listed in `hdrs` only: it was previously repeated in `srcs`,
# which is redundant because `hdrs` already makes it available both to this
# target and to its dependents.
cc_library(
    name = "hipsparse_wrapper",
    hdrs = if_rocm_is_configured(["hipsparse_wrapper.h"]),
    deps = if_rocm_is_configured([
        # Same package.
        ":hipsparse_if_static",
        ":rocm_gpu_executor",
        ":rocm_platform_id",
        # ROCm SDK headers.
        "@local_config_rocm//rocm:rocm_headers",
        # StreamExecutor.
        "//tensorflow/stream_executor/lib",
        "//tensorflow/stream_executor/platform",
        "//tensorflow/stream_executor/platform:dso_loader",
    ]),
    # Retained from the original target definition.
    alwayslink = True,
)
328
# Umbrella target bundling the ROCm plugins, driver and platform listed in
# `deps`. Publicly visible; it has no sources of its own.
cc_library(
    name = "all_runtime",
    copts = tf_copts(),
    visibility = ["//visibility:public"],
    deps = if_rocm_is_configured([
        ":miopen_plugin",
        ":rocfft_plugin",
        ":rocblas_plugin",
        ":rocrand_plugin",
        ":rocm_driver",
        ":rocm_platform",
    ]),
    # Boolean attribute written as `True` (previously `1`) to match every other
    # target in this file.
    alwayslink = True,
)
343