# Bazel build targets for the parallel device library and its tests.
load("//tensorflow:tensorflow.bzl", "filegroup")
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load(
    "//tensorflow:tensorflow.bzl",
    "tf_cc_test",
)

# Package-level declaration: every target in this package is licensed "notice".
package(
    licenses = ["notice"],
)

# Currently pybind extension shared objects must use only C API headers since
# the C API has static initializers duplicated in the Python bindings. So we
# need a second rule that omits .cc files, in
# tensorflow/python:_pywrap_parallel_device.
#
# Header of the parallel device library; used as `hdrs` of
# :parallel_device_lib and aggregated into :headers.
filegroup(
    name = "lib_headers",
    srcs = ["parallel_device_lib.h"],
)

# Implementation file of the parallel device library; used as `srcs` of
# :parallel_device_lib and aggregated into :sources.
filegroup(
    name = "lib_sources",
    srcs = ["parallel_device_lib.cc"],
)

# Header of the parallel device target; used as `hdrs` of :parallel_device
# and aggregated into :headers.
filegroup(
    name = "device_headers",
    srcs = ["parallel_device.h"],
)

# Implementation file of the parallel device target; used as `srcs` of
# :parallel_device and aggregated into :sources.
filegroup(
    name = "device_sources",
    srcs = ["parallel_device.cc"],
)

# All headers of this package (device + lib), exported only to the
# //tensorflow/python package.
filegroup(
    name = "headers",
    srcs = [
        ":device_headers",
        ":lib_headers",
    ],
    visibility = ["//tensorflow/python:__pkg__"],
)

# All implementation files of this package (device + lib), exported only to
# the //tensorflow/python package.
filegroup(
    name = "sources",
    srcs = [
        ":device_sources",
        ":lib_sources",
    ],
    visibility = ["//tensorflow/python:__pkg__"],
)

# C++ library built from parallel_device.{h,cc} (via the :device_* filegroups)
# on top of :parallel_device_lib. Visible to TensorFlow-internal packages.
cc_library(
    name = "parallel_device",
    srcs = [":device_sources"],
    hdrs = [":device_headers"],
    visibility = ["//tensorflow:internal"],
    deps = [
        ":parallel_device_lib",
        "//tensorflow/c:c_api",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/c/eager:c_api_experimental",
        "//tensorflow/c/eager:tfe_tensorhandle_internal",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:variant",
    ],
)

# C++ library built from parallel_device_lib.{h,cc} (via the :lib_*
# filegroups). Visible to TensorFlow-internal packages.
cc_library(
    name = "parallel_device_lib",
    srcs = [":lib_sources"],
    hdrs = [":lib_headers"],
    visibility = ["//tensorflow:internal"],
    deps = [
        "//tensorflow/c:c_api",
        "//tensorflow/c:tf_status_internal",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/c/eager:c_api_experimental",
        "//tensorflow/c/eager:tfe_cancellation_manager_internal",
        "//tensorflow/c/eager:tfe_op_internal",
        "//tensorflow/c/eager:tfe_tensorhandle_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
        "@com_google_absl//absl/types:variant",
    ],
)

# Test for :parallel_device_lib, using the shared helpers in
# :parallel_device_testlib.
tf_cc_test(
    name = "parallel_device_lib_test",
    srcs = ["parallel_device_lib_test.cc"],
    deps = [
        ":parallel_device_lib",
        ":parallel_device_testlib",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_api_experimental",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/c/eager:c_api_experimental",
        "//tensorflow/c/eager:tfe_context_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/common_runtime/eager:context",
    ],
)

# Test-only helper library shared by the tests in this package.
cc_library(
    name = "parallel_device_testlib",
    # `testonly` is a boolean attribute; spell it as a boolean rather than 1.
    testonly = True,
    srcs = ["parallel_device_testlib.cc"],
    hdrs = ["parallel_device_testlib.h"],
    deps = [
        ":parallel_device",
        ":parallel_device_lib",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_api_experimental",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/c/eager:c_api_experimental",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

# Test for :parallel_device, using the shared helpers in
# :parallel_device_testlib.
tf_cc_test(
    name = "parallel_device_test",
    srcs = ["parallel_device_test.cc"],
    deps = [
        ":parallel_device",
        ":parallel_device_testlib",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_api_experimental",
        "//tensorflow/c:tf_status_internal",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/c/eager:c_api_experimental",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "//tensorflow/c/eager:tfe_tensorhandle_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

# Test for :parallel_device that additionally links the gRPC server library
# from the distributed runtime.
tf_cc_test(
    name = "parallel_device_remote_test",
    srcs = ["parallel_device_remote_test.cc"],
    # The heap-check flag is passed with an empty value for now.
    # TODO(b/136478427): Enable global heap checking when servers shut down
    # cleanly.
    args = ["--heap_check="],
    deps = [
        ":parallel_device",
        ":parallel_device_testlib",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_api_experimental",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/c/eager:c_api_experimental",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
    ],
)
