# Description: Operations defined for Cloud TPUs

load("//tensorflow:tensorflow.bzl", "tf_py_test")
load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library")
load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test")

# Do not add any more paths to default_visibility. You do not need to be in
# the visibility list to use TPU symbols. They are accessible from
# tf.contrib.tpu in TF 1.x and tf.tpu and tf.compat.v1.tpu in TF 2.x.
package(
    default_visibility = [
        "//learning/brain:__subpackages__",
        "//learning/deepmind:__subpackages__",
        "//research/graph:__subpackages__",
        "//tensorflow:__subpackages__",
    ],
    licenses = ["notice"],  # Apache 2.0
)

# Makes tpu_test_wrapper.py referenceable by label from other packages.
exports_files(["tpu_test_wrapper.py"])

# Test for tpu_test_wrapper.py. The wrapper source is compiled into the test
# target alongside the test file; `main` selects the test file as entry point.
py_test(
    name = "tpu_test_wrapper_test",
    srcs = [
        "tpu_test_wrapper.py",
        "tpu_test_wrapper_test.py",
    ],
    main = "tpu_test_wrapper_test.py",
    python_version = "PY3",
    srcs_version = "PY3",
    tags = [
        "no_oss_py2",
        "no_oss_py35",
        "no_pip",
    ],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:platform",
        "@absl_py//absl/testing:flagsaver",
    ],
)

# Python TPU op wrappers (ops/tpu_ops.py), layered over
# //tensorflow/python:tpu_ops_gen.
py_library(
    name = "tpu_py",
    srcs = ["ops/tpu_ops.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:tpu_ops_gen",
    ],
)

# async_checkpoint.py; depends on the Estimator, training and summary
# libraries.
py_library(
    name = "async_checkpoint",
    srcs = ["async_checkpoint.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:state_ops",
        "//tensorflow/python:summary",
        "//tensorflow/python:summary_ops_v2",
        "//tensorflow/python:training",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/estimator:estimator_py",
    ],
)

# Test for :async_checkpoint, declared through the tpu_py_test macro from
# tpu.bzl.
tpu_py_test(
    name = "async_checkpoint_test",
    size = "medium",
    srcs = ["async_checkpoint_test.py"],
    disable_experimental = True,
    deps = [
        ":async_checkpoint",
        ":tpu_estimator",
        ":tpu_lib",
        "//tensorflow/python:lib",
        "//tensorflow/python:platform",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//third_party/py/numpy",
    ],
)

# preempted_hook.py; depends on session_run_hook and the TPU cluster resolver.
py_library(
    name = "preempted_hook_py",
    srcs = ["preempted_hook.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:errors",
        "//tensorflow/python:platform",
        "//tensorflow/python:session_run_hook",
        "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
    ],
)

# TPUEstimator (tpu_estimator.py) and its supporting modules: config, context,
# error handling, embedding glue and utilities. Pulls in the async checkpoint,
# preemption hook, TPU embedding and feature column libraries from this
# package.
py_library(
    name = "tpu_estimator",
    srcs = [
        "_tpu_estimator_embedding.py",
        "error_handling.py",
        "tpu_config.py",
        "tpu_context.py",
        "tpu_estimator.py",
        "util.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":async_checkpoint",
        ":feature_column",
        ":feature_column_v2",
        ":functional",
        ":preempted_hook_py",
        ":tpu_embedding",
        ":tpu_lib",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:function",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:session",
        "//tensorflow/python:state_ops",
        "//tensorflow/python:summary",
        "//tensorflow/python:summary_ops_v2",
        "//tensorflow/python:training",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "//tensorflow/python/estimator:estimator_py",
        "//tensorflow/python/estimator:util",
        "@six_archive//:six",
    ],
)

# functional.py. Overrides the package default_visibility so that any package
# may depend on it.
py_library(
    name = "functional",
    srcs = ["functional.py"],
    srcs_version = "PY2AND3",
    visibility = [
        "//visibility:public",
    ],
    deps = [
        "//tensorflow/python:tpu_ops_gen",
    ],
)

# Umbrella target: the package __init__.py plus the TPU libraries, including
# TPUEstimator.
py_library(
    name = "tpu",
    srcs = [
        "__init__.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":feature_column",
        ":feature_column_v2",
        ":tpu_embedding",
        ":tpu_estimator",
        ":tpu_lib",
    ],
)

# Variant of :tpu without the :tpu_estimator dependency. It additionally ships
# api.py and depends on the preemption hook library.
py_library(
    name = "tpu_noestimator",
    srcs = [
        "__init__.py",
        "api.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":feature_column",
        ":feature_column_v2",
        ":preempted_hook_py",
        ":tpu_embedding",
        ":tpu_lib",
    ],
)

# Core TPU Python library: tpu.py, tpu_feed.py, device_assignment.py,
# topology.py, tpu_sharding.py, bfloat16.py, the tensor tracer modules,
# training_loop.py and related helpers.
# NOTE(review): tpu_strategy_util.py is also the only source of
# :tpu_strategy_util, and __init__.py is also listed in :tpu and
# :tpu_noestimator. Listing one file in several libraries is usually avoided;
# consider de-duplicating once dependents are confirmed.
py_library(
    name = "tpu_lib",
    srcs = [
        "__init__.py",
        "bfloat16.py",
        "device_assignment.py",
        "session_support.py",
        "tensor_tracer.py",
        "tensor_tracer_flags.py",
        "tensor_tracer_report.py",
        "topology.py",
        "tpu.py",
        "tpu_feed.py",
        "tpu_function.py",
        "tpu_optimizer.py",
        "tpu_sharding.py",
        "tpu_strategy_util.py",
        "tpu_system_metadata.py",
        "training_loop.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":datasets",
        ":functional",
        ":tpu_py",
        "//tensorflow/compiler/xla/experimental/xla_sharding",
        "//tensorflow/compiler/xla/python_api:xla_shape",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/core/protobuf/tpu:compilation_result_proto_py",
        "//tensorflow/core/protobuf/tpu:dynamic_padding_proto_py",
        "//tensorflow/core/protobuf/tpu:optimization_parameters_proto_py",
        "//tensorflow/core/protobuf/tpu:topology_proto_py",
        "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py",
        "//tensorflow/core/protobuf/tpu:tpu_embedding_output_layout_proto_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:batch_ops",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:control_flow_util",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:platform_analytics",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tpu_ops_gen",
        "//tensorflow/python:training",
        "//tensorflow/python:util",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/compiler/xla",
        "//tensorflow/python/ops/losses",
        "//tensorflow/python/tpu:tensor_tracer_proto_py",
        "//tensorflow/python/tpu/profiler",
        "@six_archive//:six",
    ],
)

# datasets.py; built on the tf.data dataset, iterator and reader ops.
py_library(
    name = "datasets",
    srcs = [
        "datasets.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:dtypes",
        "//tensorflow/python:function",
        "//tensorflow/python:functional_ops",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:iterator_ops",
        "//tensorflow/python/data/ops:readers",
    ],
)

# Test for :datasets. Runs in 4 shards with gRPC enabled and is excluded from
# OSS builds.
tf_py_test(
    name = "datasets_test",
    size = "medium",
    srcs = ["datasets_test.py"],
    grpc_enabled = True,
    shard_count = 4,
    tags = ["no_oss"],
    deps = [
        ":datasets",
        "//tensorflow/python:client_testlib",
    ],
)

# tpu_test.py, run against the :tpu umbrella target.
tf_py_test(
    name = "tpu_test",
    size = "small",
    srcs = ["tpu_test.py"],
    tags = [
        "no_oss",  # TODO(b/131157871): Re-enable in OSS when fixed
        "no_windows",  # TODO: needs investigation on Windows
    ],
    deps = [
        ":tpu",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework",
        "//tensorflow/python:layers",
    ],
)

# tpu_sharding_test.py, run against the :tpu umbrella target.
tf_py_test(
    name = "tpu_sharding_test",
    size = "small",
    srcs = ["tpu_sharding_test.py"],
    deps = [
        ":tpu",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework",
    ],
)

# bfloat16_test.py, run against the :tpu umbrella target.
tf_py_test(
    name = "bfloat16_test",
    size = "small",
    srcs = ["bfloat16_test.py"],
    deps = [
        ":tpu",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework",
    ],
)

# tpu_infeed_test.py, run against the :tpu umbrella target.
tf_py_test(
    name = "tpu_infeed_test",
    size = "small",
    srcs = ["tpu_infeed_test.py"],
    deps = [
        ":tpu",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_test_lib",
    ],
)

# topology_test.py, run against the :tpu umbrella target.
tf_py_test(
    name = "topology_test",
    size = "medium",
    srcs = ["topology_test.py"],
    deps = [
        ":tpu",
        "//tensorflow/python:framework_test_lib",
    ],
)

# TPU embedding modules (tpu_embedding.py, tpu_embedding_gradient.py); depends
# on the tpu_embedding_configuration proto and :tpu_lib.
py_library(
    name = "tpu_embedding",
    srcs = [
        "tpu_embedding.py",
        "tpu_embedding_gradient.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":tpu_lib",
        "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:partitioned_variables",
        "//tensorflow/python:tpu_ops_gen",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python:variables",
        "@six_archive//:six",
    ],
)

# tpu_strategy_util.py as a standalone target (the file is also listed in
# :tpu_lib srcs). srcs_version is stated explicitly to match the other
# py_library targets in this file; PY2AND3 is also Bazel's default.
py_library(
    name = "tpu_strategy_util",
    srcs = ["tpu_strategy_util.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":tpu_lib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:util",
        "//tensorflow/python/distribute:device_util",
        "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:tape",
    ],
)

# TPU feature columns (feature_column.py), built on
# //tensorflow/python/feature_column. srcs_version is stated explicitly to
# match the other py_library targets in this file; PY2AND3 is also Bazel's
# default.
py_library(
    name = "feature_column",
    srcs = ["feature_column.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":tpu_lib",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/feature_column",
        "//tensorflow/python/feature_column:feature_column_py",
    ],
)

# feature_column_v2.py; depends on :feature_column. srcs_version is stated
# explicitly to match the other py_library targets in this file; PY2AND3 is
# also Bazel's default.
py_library(
    name = "feature_column_v2",
    srcs = ["feature_column_v2.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":feature_column",
        ":tpu_lib",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/feature_column",
        "//tensorflow/python/feature_column:feature_column_py",
    ],
)

# Test for :feature_column.
tf_py_test(
    name = "feature_column_test",
    srcs = [
        "feature_column_test.py",
    ],
    main = "feature_column_test.py",
    deps = [
        ":feature_column",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:lookup_ops",
        "//tensorflow/python:parsing_ops",
        "//tensorflow/python:session",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:variables",
        "//tensorflow/python/feature_column",
        "//tensorflow/python/feature_column:feature_column_py",
        "//third_party/py/numpy",
    ],
)

# Test for :feature_column_v2.
tf_py_test(
    name = "feature_column_v2_test",
    srcs = [
        "feature_column_v2_test.py",
    ],
    main = "feature_column_v2_test.py",
    deps = [
        ":feature_column_v2",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:lookup_ops",
        "//tensorflow/python:parsing_ops",
        "//tensorflow/python:session",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:variables",
        "//tensorflow/python/feature_column",
        "//tensorflow/python/feature_column:feature_column_py",
        "//third_party/py/numpy",
    ],
)

# Proto library for tensor_tracer.proto; depends on the core TensorFlow protos
# and is visible to all packages.
tf_proto_library(
    name = "tensor_tracer_proto",
    srcs = ["tensor_tracer.proto"],
    cc_api_version = 2,
    protodeps = [
        "//tensorflow/core:protos_all",
    ],
    visibility = ["//visibility:public"],
)