• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Description:
#   Trackable class and subclass definitions.
3
4load("//tensorflow:tensorflow.bzl", "tf_py_test")
5
# Package-wide defaults. Individual rules may override `visibility`
# (see the `resource` target below for one that does).
package(
    default_visibility = ["//tensorflow:internal"],
    licenses = ["notice"],
)
12
# Umbrella target: has no sources of its own and exists only to pull in the
# individual trackable libraries declared in this BUILD file.
py_library(
    name = "trackable",
    deps = [
        ":asset",
        ":autotrackable",
        ":base",
        ":base_delegate",
        ":constants",
        ":converter",
        ":data_structures",
        ":layer_utils",
        ":python_state",
        ":resource",
        ":trackable_init",
        ":trackable_utils",
    ],
)
30
# The package's `__init__.py`, as its own library target.
py_library(
    name = "trackable_init",
    srcs = ["__init__.py"],
    srcs_version = "PY3",
)
36
# Library built from base.py; depends on `:constants` plus framework, eager
# context and saveable-object targets.
py_library(
    name = "base",
    srcs = ["base.py"],
    srcs_version = "PY3",
    deps = [
        ":constants",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:util",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/training/saving:saveable_object",
    ],
)
51
# Tests for the `:base` library.
tf_py_test(
    name = "base_test",
    srcs = ["base_test.py"],
    deps = [
        ":base",
        "//tensorflow/python:client_testlib",
    ],
)
60
# Library built from constants.py; declares no dependencies.
py_library(
    name = "constants",
    srcs = ["constants.py"],
    srcs_version = "PY3",
)
66
# Library built from converter.py; builds on `:data_structures` and the eager
# function_saved_model_utils target.
py_library(
    name = "converter",
    srcs = ["converter.py"],
    srcs_version = "PY3",
    deps = [
        ":data_structures",
        "//tensorflow/python/eager:function_saved_model_utils",
    ],
)
76
# Library built from trackable_utils.py; declares no dependencies.
py_library(
    name = "trackable_utils",
    srcs = ["trackable_utils.py"],
    srcs_version = "PY3",
)
82
# Tests for the `:trackable_utils` library.
tf_py_test(
    name = "trackable_utils_test",
    srcs = ["trackable_utils_test.py"],
    deps = [
        ":trackable_utils",
        "//tensorflow/python/eager:test",
    ],
)
91
# Library built from base_delegate.py; its only declared dependency is tf_export.
py_library(
    name = "base_delegate",
    srcs = ["base_delegate.py"],
    srcs_version = "PY3",
    deps = ["//tensorflow/python/util:tf_export"],
)
100
# Tests for `:base_delegate`. The checkpoint and saved_model load/save deps are
# pulled in alongside the libraries under test.
tf_py_test(
    name = "base_delegate_test",
    srcs = ["base_delegate_test.py"],
    deps = [
        ":base",
        ":base_delegate",
        "//tensorflow/python:extra_py_tests_deps",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:variables",
        "//tensorflow/python/checkpoint",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/saved_model:load",
        "//tensorflow/python/saved_model:save",
    ],
)
116
# Library built from asset.py, layered on `:base`.
py_library(
    name = "asset",
    srcs = ["asset.py"],
    srcs_version = "PY3",
    deps = [
        ":base",
        "//tensorflow/python:lib",
        "//tensorflow/python/eager:context",
    ],
)
127
# Library built from autotrackable.py; depends on `:base` and `:data_structures`.
py_library(
    name = "autotrackable",
    srcs = ["autotrackable.py"],
    srcs_version = "PY3",
    deps = [
        ":base",
        ":data_structures",
    ],
)
137
# Tests for the `:autotrackable` library.
tf_py_test(
    name = "autotrackable_test",
    srcs = ["autotrackable_test.py"],
    deps = [
        ":autotrackable",
        ":data_structures",
        "//tensorflow/python:client_testlib",
    ],
)
147
# Library built from resource.py. Unlike the rest of this package, it widens
# visibility beyond the package default (see TODO).
py_library(
    name = "resource",
    srcs = ["resource.py"],
    srcs_version = "PY3",
    # TODO(b/238780047): Clean up the Grand Vision code and
    # revert to only default_visibility here once automated
    # tracking of resources is implemented.
    visibility = [
        "//tensorflow:internal",
        "//third_party/py/grand_vision/google:__subpackages__",
    ],
    deps = [":base"],
)
163
# Tests for the `:resource` library.
tf_py_test(
    name = "resource_test",
    srcs = ["resource_test.py"],
    deps = [
        ":base",
        # The library under test must be a direct dependency; previously it
        # was missing and only reachable (if at all) transitively.
        ":resource",
        "//tensorflow/python:client_testlib",
    ],
)
172
# Library built from layer_utils.py; declares no dependencies.
py_library(
    name = "layer_utils",
    srcs = ["layer_utils.py"],
    srcs_version = "PY3",
)
178
# Library built from data_structures.py; uses `:base`, `:layer_utils`,
# saved_model revived types, and the external `wrapt` package.
py_library(
    name = "data_structures",
    srcs = ["data_structures.py"],
    srcs_version = "PY3",
    deps = [
        ":base",
        ":layer_utils",
        "//tensorflow/python/saved_model:revived_types",
        "@wrapt",
    ],
)
190
# Tests for `:data_structures`. The `no_windows` / `nomac` tags opt this test
# out of Windows and macOS runs.
tf_py_test(
    name = "data_structures_test",
    srcs = ["data_structures_test.py"],
    tags = [
        "no_windows",
        "nomac",
    ],
    deps = [
        ":data_structures",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:layers",
        "//tensorflow/python:math_ops",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:test",
    ],
)
208
# Library built from python_state.py; depends on `:base` and tf_export.
py_library(
    name = "python_state",
    srcs = ["python_state.py"],
    srcs_version = "PY3",
    deps = [
        ":base",
        "//tensorflow/python/util:tf_export",
    ],
)
218
# Tests for `:python_state`, with checkpoint and module support available.
tf_py_test(
    name = "python_state_test",
    srcs = ["python_state_test.py"],
    deps = [
        ":python_state",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python/checkpoint",
        "//tensorflow/python/module",
    ],
)
230