• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Experimental Unified APIs for Eager and Graph modes.
2
3# buildifier: disable=same-origin-load
4load("//tensorflow:tensorflow.bzl", "cuda_py_test")
5
6# buildifier: disable=same-origin-load
7load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
8
# Package-level defaults applied to every target declared in this BUILD file:
# targets are visible to the //tensorflow:internal package group unless a
# target sets its own visibility, and the package is declared under the
# "notice" license type.
9package(
10    default_visibility = ["//tensorflow:internal"],
11    licenses = ["notice"],
12)
13
# Pybind extension target `_unified_api`, built from unified_api.cc and
# exposed to Python under the module name `_unified_api`.
# The "-layering_check" feature entry switches the layering_check feature off
# for this target (a leading "-" disables a feature).
# NOTE(review): the macro is loaded from //tensorflow:tensorflow.bzl; what it
# expands to is not visible from this file.
14tf_python_pybind_extension(
15    name = "_unified_api",
16    srcs = ["unified_api.cc"],
17    features = ["-layering_check"],
18    module_name = "_unified_api",
19    deps = [
20        "//tensorflow/c/eager:tfe_tensorhandle_internal",
21        "//tensorflow/core:lib",
22        "//tensorflow/core:protos_all_cc",
23        "//tensorflow/core/lib/llvm_rtti",
24        "//tensorflow/python:pybind11_lib",
25        "//tensorflow/python:unified_api_pywrap_required_headers",
26        "@pybind11",
27    ],
28)
29
# Pybind extension target `_tape`, built from tape.cc and exposed to Python
# under the module name `_tape`.
# The "-layering_check" feature entry switches the layering_check feature off
# for this target (a leading "-" disables a feature).
# The dependency list is the same set of labels used by `:_unified_api`.
30tf_python_pybind_extension(
31    name = "_tape",
32    srcs = ["tape.cc"],
33    features = ["-layering_check"],
34    module_name = "_tape",
35    deps = [
36        "//tensorflow/c/eager:tfe_tensorhandle_internal",
37        "//tensorflow/core:lib",
38        "//tensorflow/core:protos_all_cc",
39        "//tensorflow/core/lib/llvm_rtti",
40        "//tensorflow/python:pybind11_lib",
41        "//tensorflow/python:unified_api_pywrap_required_headers",
42        "@pybind11",
43    ],
44)
45
# Pybind extension target `_math_ops`, built from math_ops.cc and exposed to
# Python under the module name `_math_ops`.
# Compared with `:_unified_api` and `:_tape`, this target additionally depends
# on //tensorflow/core:framework and absl's span library.
# NOTE(review): unlike `:_unified_api` and `:_tape`, no `features` attribute is
# set here, so layering_check is not explicitly disabled -- confirm this
# difference is intentional.
46tf_python_pybind_extension(
47    name = "_math_ops",
48    srcs = ["math_ops.cc"],
49    module_name = "_math_ops",
50    deps = [
51        "//tensorflow/c/eager:tfe_tensorhandle_internal",
52        "//tensorflow/core:framework",
53        "//tensorflow/core:lib",
54        "//tensorflow/core:protos_all_cc",
55        "//tensorflow/core/lib/llvm_rtti",
56        "//tensorflow/python:pybind11_lib",
57        "//tensorflow/python:unified_api_pywrap_required_headers",
58        "@com_google_absl//absl/types:span",
59        "@pybind11",
60    ],
61)
62
# Pybind extension target `_nn_ops`, built from nn_ops.cc and exposed to
# Python under the module name `_nn_ops`.
# Its dependency list is the same set of labels used by `:_math_ops`.
# NOTE(review): unlike `:_unified_api` and `:_tape`, no `features` attribute is
# set here, so layering_check is not explicitly disabled -- confirm this
# difference is intentional.
63tf_python_pybind_extension(
64    name = "_nn_ops",
65    srcs = ["nn_ops.cc"],
66    module_name = "_nn_ops",
67    deps = [
68        "//tensorflow/c/eager:tfe_tensorhandle_internal",
69        "//tensorflow/core:framework",
70        "//tensorflow/core:lib",
71        "//tensorflow/core:protos_all_cc",
72        "//tensorflow/core/lib/llvm_rtti",
73        "//tensorflow/python:pybind11_lib",
74        "//tensorflow/python:unified_api_pywrap_required_headers",
75        "@com_google_absl//absl/types:span",
76        "@pybind11",
77    ],
78)
79
# Python library target for gradient_registry.py (Python 3 sources).
# Its only declared dependency is the `:_tape` pybind extension.
80py_library(
81    name = "gradient_registry",
82    srcs = ["gradient_registry.py"],
83    srcs_version = "PY3",
84    deps = [":_tape"],
85)
86
# Python library target for math_ops.py (Python 3 sources).
# Depends on the `:_math_ops` pybind extension and on `:context_stack`.
87py_library(
88    name = "math_ops",
89    srcs = ["math_ops.py"],
90    srcs_version = "PY3",
91    deps = [
92        ":_math_ops",
93        ":context_stack",
94    ],
95)
96
# Python library target for nn_ops.py (Python 3 sources).
# Depends on the `:_nn_ops` pybind extension and on `:context_stack`.
97py_library(
98    name = "nn_ops",
99    srcs = ["nn_ops.py"],
100    srcs_version = "PY3",
101    deps = [
102        ":_nn_ops",
103        ":context_stack",
104    ],
105)
106
# Python library target for tape.py (Python 3 sources).
# Depends on the `:_tape` pybind extension, `:context_stack`,
# `:gradient_registry`, and the `nest` target from
# //tensorflow/python/data/util.
107py_library(
108    name = "tape",
109    srcs = ["tape.py"],
110    srcs_version = "PY3",
111    deps = [
112        ":_tape",
113        ":context_stack",
114        ":gradient_registry",
115        "//tensorflow/python/data/util:nest",
116    ],
117)
118
# Python library target for def_function.py (Python 3 sources).
# NOTE(review): no `deps` are declared -- verify that def_function.py imports
# nothing that would require a Bazel dependency.
119py_library(
120    name = "def_function",
121    srcs = ["def_function.py"],
122    srcs_version = "PY3",
123)
124
# Python library target for thread_local_stack.py (Python 3 sources).
# No `deps` are declared; `:context_stack` depends on this target.
125py_library(
126    name = "thread_local_stack",
127    srcs = ["thread_local_stack.py"],
128    srcs_version = "PY3",
129)
130
# Python library target for context_stack.py (Python 3 sources).
# Its only declared dependency is `:thread_local_stack`; within this file it
# is itself a dependency of `:math_ops`, `:nn_ops`, `:tape`, and
# `:unified_api_test`.
131py_library(
132    name = "context_stack",
133    srcs = ["context_stack.py"],
134    srcs_version = "PY3",
135    deps = [":thread_local_stack"],
136)
137
# Test target for unified_api_test.py, declared with the `cuda_py_test` macro
# loaded from //tensorflow:tensorflow.bzl and sized "small".
# It is tagged "no_pip" and "no_windows" (tracked in b/168218876); the inline
# notes below give the reason for "no_pip".
# The test depends on the `:_unified_api` extension, the Python libraries
# defined in this file (`:context_stack`, `:def_function`, `:math_ops`,
# `:nn_ops`, `:tape`), TensorFlow's client_testlib, and absl's parameterized
# testing library.
# NOTE(review): presumably cuda_py_test also produces a GPU variant of the
# test -- confirm against the macro definition in tensorflow.bzl.
138cuda_py_test(
139    name = "unified_api_test",
140    size = "small",
141    srcs = ["unified_api_test.py"],
142    tags = [
143        # Note(srbs): These python bindings are not
144        # exported as part of the pip package yet so
145        # this test is disabled.
146        "no_pip",
147        "no_windows",  # b/168218876
148    ],
149    deps = [
150        ":_unified_api",
151        ":context_stack",
152        ":def_function",
153        ":math_ops",
154        ":nn_ops",
155        ":tape",
156        "//tensorflow/python:client_testlib",
157        "@absl_py//absl/testing:parameterized",
158    ],
159)
160