• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Description:
#   Contains the Keras utilities (internal TensorFlow version).
3
4load("//tensorflow:tensorflow.bzl", "tf_py_test")
5load("//tensorflow:tensorflow.bzl", "cuda_py_test")
6
package(
    # By default, targets here are visible to Keras subpackages plus the
    # feature_column and pip_package consumers listed below.
    # TODO(scottzhu): Remove non-keras deps from TF.
    default_visibility = [
        "//tensorflow/python/feature_column:__pkg__",
        "//tensorflow/python/keras:__subpackages__",
        "//tensorflow/tools/pip_package:__pkg__",
    ],
    licenses = ["notice"],  # Apache 2.0
)
16
# Every *.py file directly in this directory, exported only to the private
# TF API test package.
filegroup(
    name = "all_py_srcs",
    srcs = glob(["*.py"]),
    visibility = ["//tensorflow/python/keras/google/private_tf_api_test:__pkg__"],
)
22
# Package target: __init__.py layered on top of :all_utils.
py_library(
    name = "utils",
    srcs = ["__init__.py"],
    srcs_version = "PY3",
    deps = [":all_utils"],
)
33
# Aggregate target: all_utils.py together with each of the individual
# utility libraries listed in deps.
py_library(
    name = "all_utils",
    srcs = ["all_utils.py"],
    srcs_version = "PY3",
    deps = [
        ":control_flow_util",
        ":engine_utils",
        ":generic_utils",
        ":layer_utils",
        ":multi_gpu_utils",
        ":np_utils",
        ":vis_utils",
    ],
)
50
# control_flow_util.py; declares no Bazel dependencies (deps defaults to []).
py_library(
    name = "control_flow_util",
    srcs = ["control_flow_util.py"],
    srcs_version = "PY3",
)
57
# kpl_test_utils.py; depends on Keras and the preprocessing :string_lookup
# target.
py_library(
    name = "kpl_test_utils",
    srcs = [
        "kpl_test_utils.py",
    ],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python/keras",
        "//tensorflow/python/keras/layers/preprocessing:string_lookup",
    ],
)
67
# data_utils.py, built on the generic, I/O and inspection helpers.
py_library(
    name = "data_utils",
    srcs = [
        "data_utils.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":generic_utils",
        ":io_utils",
        ":tf_inspect",
    ],
)
78
# Bundles conv_utils.py and losses_utils.py into a single target.
py_library(
    name = "engine_utils",
    srcs = [
        "conv_utils.py",
        "losses_utils.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":data_utils",
        ":io_utils",
        "//tensorflow/python/keras:backend",
    ],
)
92
# io_utils.py; its only declared dependency is `six`.
py_library(
    name = "io_utils",
    srcs = [
        "io_utils.py",
    ],
    srcs_version = "PY3",
    deps = ["@six_archive//:six"],
)
101
# tf_utils.py together with the TensorFlow core targets it relies on.
py_library(
    name = "tf_utils",
    srcs = [
        "tf_utils.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":object_identity",
        "//tensorflow/python:composite_tensor",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:smart_cond",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python/eager:context",
        "@six_archive//:six",
    ],
)
120
# generic_utils.py, built on the local contextlib/inspect wrappers and numpy.
py_library(
    name = "generic_utils",
    srcs = ["generic_utils.py"],
    srcs_version = "PY3",
    deps = [
        ":tf_contextlib",
        ":tf_inspect",
        "//tensorflow/python:util",
        "//third_party/py/numpy",
    ],
)
134
# mode_keys.py; depends on the saved_model model_utils mode_keys target.
py_library(
    name = "mode_keys",
    srcs = ["mode_keys.py"],
    srcs_version = "PY3",
    deps = ["//tensorflow/python/saved_model/model_utils:mode_keys"],
)
145
# Bundles kernelized_utils.py and layer_utils.py into a single target.
py_library(
    name = "layer_utils",
    srcs = [
        "kernelized_utils.py",
        "layer_utils.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":engine_utils",
        "//tensorflow/python:util",
        "//tensorflow/python/keras:backend",
        "//third_party/py/numpy",
    ],
)
160
# metrics_utils.py plus the TF op, ragged-tensor, distribute and TPU targets
# it depends on.
py_library(
    name = "metrics_utils",
    srcs = ["metrics_utils.py"],
    srcs_version = "PY3",
    deps = [
        ":generic_utils",
        ":tf_utils",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:check_ops",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:distribute",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:nn_ops",
        "//tensorflow/python:util",
        "//tensorflow/python:weights_broadcast_ops",
        "//tensorflow/python/ops/losses",
        "//tensorflow/python/ops/ragged:ragged_tensor",
        "//tensorflow/python/ops/ragged:ragged_util",
        "//tensorflow/python/tpu:tpu_lib",
    ],
)
186
# version_utils.py; depends on framework_ops and util.
py_library(
    name = "version_utils",
    srcs = ["version_utils.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:util",
    ],
)
198
# multi_gpu_utils.py; depends on the Keras backend and layers.
py_library(
    name = "multi_gpu_utils",
    srcs = ["multi_gpu_utils.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:util",
        "//tensorflow/python/keras:backend",
        "//tensorflow/python/keras/layers",
    ],
)
213
# np_utils.py; depends on util and numpy.
py_library(
    name = "np_utils",
    srcs = ["np_utils.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python:util",
        "//third_party/py/numpy",
    ],
)
225
# object_identity.py; declares no Bazel dependencies (deps defaults to []).
py_library(
    name = "object_identity",
    srcs = [
        "object_identity.py",
    ],
    srcs_version = "PY3",
)
232
# tf_contextlib.py; depends only on //tensorflow/python:util.
py_library(
    name = "tf_contextlib",
    srcs = [
        "tf_contextlib.py",
    ],
    srcs_version = "PY3",
    deps = ["//tensorflow/python:util"],
)
241
# tf_inspect.py; depends only on //tensorflow/python:util.
py_library(
    name = "tf_inspect",
    srcs = [
        "tf_inspect.py",
    ],
    srcs_version = "PY3",
    deps = ["//tensorflow/python:util"],
)
250
# vis_utils.py; depends only on //tensorflow/python:util.
py_library(
    name = "vis_utils",
    srcs = ["vis_utils.py"],
    srcs_version = "PY3",
    deps = ["//tensorflow/python:util"],
)
261
# dataset_creator.py; depends only on //tensorflow/python:util.
py_library(
    name = "dataset_creator",
    srcs = ["dataset_creator.py"],
    srcs_version = "PY3",
    deps = ["//tensorflow/python:util"],
)
272
# Test target for dataset_creator_test.py; currently excluded from TFRT runs.
tf_py_test(
    name = "dataset_creator_test",
    srcs = [
        "dataset_creator_test.py",
    ],
    python_version = "PY3",
    tags = [
        "no_tfrt",  # TODO(b/180537361): Reenable TFRT after the issue is resolved.
    ],
    deps = [
        ":dataset_creator",
        "//tensorflow/python/distribute:multi_worker_test_base",
        "//tensorflow/python/keras/engine",
        "//tensorflow/python/keras/layers:core",
    ],
)
287
# Test target for data_utils_test.py, split across 6 shards.
tf_py_test(
    name = "data_utils_test",
    size = "medium",
    srcs = [
        "data_utils_test.py",
    ],
    python_version = "PY3",
    shard_count = 6,
    tags = [
        "noasan",  # times out
        "notsan",
        "optonly",  # times out
    ],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
306
# Test target for generic_utils_test.py.
tf_py_test(
    name = "generic_utils_test",
    size = "small",
    srcs = [
        "generic_utils_test.py",
    ],
    python_version = "PY3",
    deps = [
        ":generic_utils",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "@absl_py//absl/testing:parameterized",
    ],
)
319
# Test target for version_utils_test.py.
tf_py_test(
    name = "version_utils_test",
    size = "small",
    srcs = [
        "version_utils_test.py",
    ],
    python_version = "PY3",
    deps = [
        ":version_utils",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "@absl_py//absl/testing:parameterized",
    ],
)
332
# Test target for tf_utils_test.py.
tf_py_test(
    name = "tf_utils_test",
    size = "small",
    srcs = [
        "tf_utils_test.py",
    ],
    python_version = "PY3",
    deps = [
        ":tf_utils",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//tensorflow/python/keras:combinations",
    ],
)
345
# Test target for composite_tensor_support_test.py, split across 8 shards.
tf_py_test(
    name = "composite_tensor_support_test",
    size = "medium",
    srcs = [
        "composite_tensor_support_test.py",
    ],
    python_version = "PY3",
    shard_count = 8,
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:sparse_ops",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python/keras",
        # NOTE(review): dataset_creator_test uses //tensorflow/python/keras/engine;
        # confirm this :engine label is the intended target.
        "//tensorflow/python/keras:engine",
        "//tensorflow/python/keras/layers",
        "//tensorflow/python/ops/ragged:ragged_tensor",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
369
# Test target for io_utils_test.py; tagged off Windows and TSan.
tf_py_test(
    name = "io_utils_test",
    size = "small",
    srcs = [
        "io_utils_test.py",
    ],
    python_version = "PY3",
    tags = [
        "no_windows",  # TODO: Windows behavior still needs investigation.
        "notsan",
    ],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
386
# Test target for layer_utils_test.py.
tf_py_test(
    name = "layer_utils_test",
    size = "small",
    srcs = [
        "layer_utils_test.py",
    ],
    python_version = "PY3",
    deps = [
        ":layer_utils",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/training/tracking",
        "//third_party/py/numpy",
    ],
)
399
# Test target for np_utils_test.py.
tf_py_test(
    name = "np_utils_test",
    size = "small",
    srcs = [
        "np_utils_test.py",
    ],
    python_version = "PY3",
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
412
# Test target for kernelized_utils_test.py (kernelized_utils.py is built
# as part of :layer_utils).
tf_py_test(
    name = "kernelized_utils_test",
    size = "small",
    srcs = [
        "kernelized_utils_test.py",
    ],
    python_version = "PY3",
    deps = [
        ":layer_utils",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:layers",
        "@absl_py//absl/testing:parameterized",
    ],
)
426
# CUDA test target for multi_gpu_utils_test.py, with XLA strict auto-jit
# enabled.
cuda_py_test(
    name = "multi_gpu_utils_test",
    srcs = [
        "multi_gpu_utils_test.py",
    ],
    python_version = "PY3",
    tags = [
        "guitar",
        "multi_gpu",
    ],
    xla_enable_strict_auto_jit = True,
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
443
# Test target for vis_utils_test.py.
tf_py_test(
    name = "vis_utils_test",
    size = "small",
    srcs = [
        "vis_utils_test.py",
    ],
    python_version = "PY3",
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
456
# Test target for conv_utils_test.py.
tf_py_test(
    name = "conv_utils_test",
    size = "small",
    srcs = [
        "conv_utils_test.py",
    ],
    python_version = "PY3",
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
469
# Test target for metrics_utils_test.py.
tf_py_test(
    name = "metrics_utils_test",
    size = "small",
    srcs = [
        "metrics_utils_test.py",
    ],
    python_version = "PY3",
    deps = [
        "//tensorflow/python:constant_op",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:ops",
        "//tensorflow/python:platform_test",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/keras",
        "//tensorflow/python/keras:combinations",
        "//tensorflow/python/ops/ragged:ragged_factory_ops",
        "//tensorflow/python/ops/ragged:ragged_tensor",
        "@absl_py//absl/testing:parameterized",
    ],
)
489
# Test target for losses_utils_test.py.
tf_py_test(
    name = "losses_utils_test",
    size = "small",
    srcs = [
        "losses_utils_test.py",
    ],
    python_version = "PY3",
    deps = [
        "//tensorflow/python:constant_op",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:ops",
        "//tensorflow/python:platform_test",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/keras",
        "//tensorflow/python/keras:combinations",
        "//tensorflow/python/ops/ragged:ragged_array_ops",
        "//tensorflow/python/ops/ragged:ragged_concat_ops",
        "//tensorflow/python/ops/ragged:ragged_factory_ops",
    ],
)
509