• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Description:
#   Contains the Keras API (internal TensorFlow version).
3
4load("//tensorflow:tensorflow.bzl", "tf_py_test")
5
# Package defaults: all targets are publicly visible and carry the
# "notice" license category (Apache 2.0).
package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],
)
10
# Expose the LICENSE file so targets in other packages can reference it.
exports_files(["LICENSE"])
12
# Umbrella library: the package-level sources listed in srcs plus the Keras
# sub-package libraries and a few non-Keras TensorFlow dependencies.
py_library(
    name = "keras",
    srcs = [
        "__init__.py",
        "estimator/__init__.py",
        "keras_parameterized.py",
        "ops.py",
        "testing_utils.py",
    ],
    srcs_version = "PY2AND3",
    # Same as the package default_visibility; kept explicit on this target.
    visibility = ["//visibility:public"],
    deps = [
        # Targets from this package.
        ":backend",
        ":engine",
        # TensorFlow (non-Keras) libraries.
        "//tensorflow/python:training",
        "//tensorflow/python/eager:monitoring",
        # Keras sub-packages.
        "//tensorflow/python/keras/applications",
        "//tensorflow/python/keras/datasets",
        "//tensorflow/python/keras/layers",
        "//tensorflow/python/keras/mixed_precision/experimental:mixed_precision_experimental",
        "//tensorflow/python/keras/optimizer_v2",
        "//tensorflow/python/keras/premade",
        "//tensorflow/python/keras/preprocessing",
        "//tensorflow/python/keras/saving",
        "//tensorflow/python/keras/utils",
        "//tensorflow/python/keras/wrappers",
        # SavedModel library.
        "//tensorflow/python/saved_model",
    ],
)
42
# Library for backend.py. It depends on many low-level TensorFlow Python op
# and framework targets plus the distribute libraries.
py_library(
    name = "backend",
    srcs = ["backend.py"],
    srcs_version = "PY2AND3",
    deps = [
        # From this package.
        ":backend_config",
        # Core protos.
        "//tensorflow/core:protos_all_py",
        # //tensorflow/python op, framework and runtime targets.
        "//tensorflow/python:array_ops",
        "//tensorflow/python:check_ops",
        "//tensorflow/python:client",
        "//tensorflow/python:clip_ops",
        "//tensorflow/python:composite_tensor_utils",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:control_flow_util",
        "//tensorflow/python:ctc_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:functional_ops",
        "//tensorflow/python:gradients",
        "//tensorflow/python:image_ops",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:init_ops_v2",
        "//tensorflow/python:logging_ops",
        "//tensorflow/python:map_fn",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:metrics",
        "//tensorflow/python:nn",
        "//tensorflow/python:platform",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:session",
        "//tensorflow/python:sparse_ops",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:state_ops",
        "//tensorflow/python:summary",
        "//tensorflow/python:tensor_array_grad",
        "//tensorflow/python:tensor_array_ops",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:training_lib",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        # //tensorflow/python/distribute targets.
        "//tensorflow/python/distribute:distribute_coordinator",
        "//tensorflow/python/distribute:distribute_lib",
        "//tensorflow/python/distribute:multi_worker_util",
    ],
)
90
# Standalone library for backend_config.py; it declares no deps.
py_library(
    name = "backend_config",
    srcs = ["backend_config.py"],
    srcs_version = "PY2AND3",
)
96
# TODO(scottzhu): Cleanup this target and point all the user to keras/engine.
# Forwarding target: re-exports //tensorflow/python/keras/engine together with
# this package's :metrics and :models libraries. Those two are py_library
# targets, so they belong in deps. The srcs attribute is meant for source
# files, and listing a library there does not bring along that library's own
# deps (e.g. :models depends on :optimizers and keras/saving).
py_library(
    name = "engine",
    srcs = [],
    srcs_version = "PY2AND3",
    deps = [
        ":metrics",
        ":models",
        "//tensorflow/python/keras/engine",
    ],
)
109
# Library for activations.py, built on :backend and the Keras engine utils.
py_library(
    name = "activations",
    srcs = ["activations.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":backend",
        "//tensorflow/python/keras/utils:engine_utils",
    ],
)
121
# TODO(scottzhu): Cleanup this target and point all the user to keras/engine.
# Source-less forwarding target: depending on it is equivalent to depending on
# //tensorflow/python/keras/engine:base_layer.
py_library(
    name = "base_layer",
    srcs = [],
    srcs_version = "PY2AND3",
    deps = ["//tensorflow/python/keras/engine:base_layer"],
)
131
# Library for callbacks.py.
py_library(
    name = "callbacks",
    srcs = ["callbacks.py"],
    srcs_version = "PY2AND3",
    deps = [
        # From this package.
        ":backend",
        # Distribute / multi-worker helpers.
        "//tensorflow/python/distribute:distributed_file_utils",
        "//tensorflow/python/keras/distribute:multi_worker_training_state",
        # Keras utilities.
        "//tensorflow/python/keras/utils:engine_utils",
        "//tensorflow/python/keras/utils:mode_keys",
        # API documentation decorators.
        "//tensorflow/tools/docs:doc_controls",
    ],
)
147
# Library for callbacks_v1.py; also depends on the eager profiler.
py_library(
    name = "callbacks_v1",
    srcs = ["callbacks_v1.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":backend",
        "//tensorflow/python/eager:profiler",
        "//tensorflow/python/keras/utils:engine_utils",
    ],
)
160
# Library for constraints.py, built on :backend and the Keras engine utils.
py_library(
    name = "constraints",
    srcs = ["constraints.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":backend",
        "//tensorflow/python/keras/utils:engine_utils",
    ],
)
172
# Library for initializers.py; additionally depends on init_ops_v2.
py_library(
    name = "initializers",
    srcs = ["initializers.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":backend",
        "//tensorflow/python:init_ops_v2",
        "//tensorflow/python/keras/utils:engine_utils",
    ],
)
185
# Library for losses.py, built on :backend and the Keras engine utils.
py_library(
    name = "losses",
    srcs = ["losses.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":backend",
        "//tensorflow/python/keras/utils:engine_utils",
    ],
)
197
# Library for metrics.py. Depends on :losses from this package and a wide set
# of TensorFlow op, eager, distribute and Keras engine/utils targets.
py_library(
    name = "metrics",
    srcs = ["metrics.py"],
    srcs_version = "PY2AND3",
    deps = [
        # From this package.
        ":backend",
        ":losses",
        # //tensorflow/python op and framework targets.
        "//tensorflow/python:array_ops",
        "//tensorflow/python:check_ops",
        "//tensorflow/python:confusion_matrix",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:nn",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python:weights_broadcast_ops",
        # Distribute and eager targets.
        "//tensorflow/python/distribute:distribute_lib",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
        # Keras sub-package targets.
        "//tensorflow/python/keras/distribute",
        "//tensorflow/python/keras/engine:base_layer",
        "//tensorflow/python/keras/engine:base_layer_utils",
        "//tensorflow/python/keras/utils:generic_utils",
        "//tensorflow/python/keras/utils:metrics_utils",
        "//tensorflow/python/keras/utils:tf_utils",
        # Other TensorFlow targets.
        "//tensorflow/python/ops/losses",
        "//tensorflow/tools/docs:doc_controls",
        # Third-party Python packages.
        "//third_party/py/numpy",
        "@six_archive//:six",
    ],
)
236
# Library for models.py; pulls in :metrics, :optimizers, the Keras engine and
# the saving package.
py_library(
    name = "models",
    srcs = ["models.py"],
    srcs_version = "PY2AND3",
    deps = [
        # From this package.
        ":backend",
        ":metrics",
        ":optimizers",
        # TensorFlow (non-Keras) targets.
        "//tensorflow/python:platform",
        "//tensorflow/python:util",
        # Keras sub-package targets.
        "//tensorflow/python/keras/engine",
        "//tensorflow/python/keras/engine:base_layer",
        "//tensorflow/python/keras/saving",
        "//tensorflow/python/keras/utils:generic_utils",
        "//tensorflow/python/keras/utils:version_utils",
    ],
)
256
# Library for optimizers.py; depends on the optimizer_v2 package.
py_library(
    name = "optimizers",
    srcs = ["optimizers.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":backend",
        "//tensorflow/python/keras/optimizer_v2",
        "//tensorflow/python/keras/utils:engine_utils",
    ],
)
269
# Library for regularizers.py, built on :backend and the Keras engine utils.
py_library(
    name = "regularizers",
    srcs = ["regularizers.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":backend",
        "//tensorflow/python/keras/utils:engine_utils",
    ],
)
281
# Small PY3 test in activations_test.py, run against the umbrella :keras target.
tf_py_test(
    name = "activations_test",
    size = "small",
    srcs = ["activations_test.py"],
    python_version = "PY3",
    deps = [
        ":keras",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:nn_ops",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
295
# Small PY3 test in constraints_test.py.
tf_py_test(
    name = "constraints_test",
    size = "small",
    srcs = ["constraints_test.py"],
    python_version = "PY3",
    deps = [
        ":keras",
        "//tensorflow/python:client_testlib",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
308
# Small PY3 test in initializers_test.py; also needs init_ops.
tf_py_test(
    name = "initializers_test",
    size = "small",
    srcs = ["initializers_test.py"],
    python_version = "PY3",
    deps = [
        ":keras",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:init_ops",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
322
# Medium-sized PY3 test in regularizers_test.py.
tf_py_test(
    name = "regularizers_test",
    size = "medium",
    srcs = ["regularizers_test.py"],
    python_version = "PY3",
    deps = [
        ":keras",
        "//tensorflow/python:client_testlib",
        "@absl_py//absl/testing:parameterized",
    ],
)
334
# Medium PY3 test in optimizers_test.py, split across 8 shards.
tf_py_test(
    name = "optimizers_test",
    size = "medium",
    srcs = ["optimizers_test.py"],
    python_version = "PY3",
    shard_count = 8,
    tags = ["notsan"],
    deps = [
        ":keras",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:training",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
350
# Small PY3 test in losses_test.py.
tf_py_test(
    name = "losses_test",
    size = "small",
    srcs = ["losses_test.py"],
    python_version = "PY3",
    deps = [
        ":keras",
        "//tensorflow/python:client_testlib",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
363
# Small PY3 test in metrics_functional_test.py (no absl parameterized dep).
tf_py_test(
    name = "metrics_functional_test",
    size = "small",
    srcs = ["metrics_functional_test.py"],
    python_version = "PY3",
    deps = [
        ":keras",
        "//tensorflow/python:client_testlib",
        "//third_party/py/numpy",
    ],
)
375
# Medium PY3 test in metrics_test.py, split across 4 shards.
tf_py_test(
    name = "metrics_test",
    size = "medium",
    srcs = ["metrics_test.py"],
    python_version = "PY3",
    shard_count = 4,
    deps = [
        ":keras",
        "//tensorflow/python:client_testlib",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
389
# Medium PY3 test in metrics_confusion_matrix_test.py, split across 4 shards.
tf_py_test(
    name = "metrics_confusion_matrix_test",
    size = "medium",
    srcs = ["metrics_confusion_matrix_test.py"],
    python_version = "PY3",
    shard_count = 4,
    deps = [
        ":keras",
        "//tensorflow/python:client_testlib",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
403
# Medium PY3 test in metrics_correctness_test.py, split across 4 shards.
tf_py_test(
    name = "metrics_correctness_test",
    size = "medium",
    srcs = ["metrics_correctness_test.py"],
    python_version = "PY3",
    shard_count = 4,
    deps = [
        ":keras",
        "//tensorflow/python:client_testlib",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
417
# Medium PY3 test in callbacks_test.py, split across 4 shards; tagged
# no_oss and notsan.
tf_py_test(
    name = "callbacks_test",
    size = "medium",
    srcs = ["callbacks_test.py"],
    python_version = "PY3",
    shard_count = 4,
    tags = ["no_oss", "notsan"],
    deps = [
        ":keras",
        "//tensorflow/python:client_testlib",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
        "@six_archive//:six",
    ],
)
436
# Medium PY3 test in callbacks_v1_test.py (unsharded), tagged notsan.
tf_py_test(
    name = "callbacks_v1_test",
    size = "medium",
    srcs = ["callbacks_v1_test.py"],
    python_version = "PY3",
    tags = ["notsan"],
    deps = [
        ":keras",
        "//tensorflow/python:client_testlib",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
450
# Medium PY3 test in models_test.py, split across 8 shards.
tf_py_test(
    name = "models_test",
    size = "medium",
    srcs = ["models_test.py"],
    python_version = "PY3",
    shard_count = 8,
    tags = [
        "no_rocm",
        # Tracking bug for the notsan tag: b/67509773.
        "notsan",
    ],
    deps = [
        ":keras",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:training",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
469
# Medium PY3 test in backend_test.py, split across 4 shards.
tf_py_test(
    name = "backend_test",
    size = "medium",
    srcs = ["backend_test.py"],
    python_version = "PY3",
    shard_count = 4,
    deps = [
        ":keras",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:util",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
484
# Medium PY3 test in backend_config_test.py (unsharded).
tf_py_test(
    name = "backend_config_test",
    size = "medium",
    srcs = ["backend_config_test.py"],
    python_version = "PY3",
    deps = [
        ":keras",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:util",
        "//third_party/py/numpy",
    ],
)
497
# Small PY3 test in keras_parameterized_test.py, tagged notsan.
tf_py_test(
    name = "keras_parameterized_test",
    size = "small",
    srcs = ["keras_parameterized_test.py"],
    python_version = "PY3",
    tags = ["notsan"],
    deps = [
        ":keras",
        "//tensorflow/python:client_testlib",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
511