• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1load("//tensorflow:tensorflow.bzl", "py_strict_test")
2load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
3
4licenses(["notice"])
5
# Python test built from multiple_results_test.py using the py_strict_test
# macro loaded from //tensorflow:tensorflow.bzl. It depends on the tf_jitrt
# Python binding, TensorFlow's client_testlib and numpy.
6py_strict_test(
7    name = "multiple_results_test",
8    srcs = ["multiple_results_test.py"],
9    python_version = "PY3",
10    tags = [
11        "no_oss",
12        "no_pip",
13    ],  # TODO(b/201803253): TFRT pybindings not in OSS.
14    deps = [
15        "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt",
16        "//tensorflow/python:client_testlib",
17        "//third_party/py/numpy",
18    ],
19)
20
# Python test built from tf_acos_test.py; depends on the tf_jitrt Python
# binding, client_testlib and numpy.
# NOTE(review): most sibling tests in this file are tagged both "no_oss" and
# "no_pip"; this one carries only "no_pip" although the TODO text is the
# same - confirm that omitting "no_oss" is intentional.
21py_strict_test(
22    name = "tf_acos_test",
23    srcs = ["tf_acos_test.py"],
24    python_version = "PY3",
25    tags = ["no_pip"],  # TODO(b/201803253): TFRT pybindings not in OSS.
26    deps = [
27        "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt",
28        "//tensorflow/python:client_testlib",
29        "//third_party/py/numpy",
30    ],
31)
32
# Python test built from tf_binary_bcast_test.py; depends on the tf_jitrt
# Python binding, client_testlib and numpy. Tagged "no_oss", "no_pip" and
# "nomsan".
# NOTE(review): "nomsan" has no tracking bug recorded here, whereas
# tf_broadcast_to_test cites b/210849019 for the same tag - confirm whether
# the same issue applies and add the reference if so.
33py_strict_test(
34    name = "tf_binary_bcast_test",
35    srcs = ["tf_binary_bcast_test.py"],
36    python_version = "PY3",
37    tags = [
38        "no_oss",
39        "no_pip",  # TODO(b/201803253): TFRT pybindings not in OSS.
40        "nomsan",
41    ],
42    deps = [
43        "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt",
44        "//tensorflow/python:client_testlib",
45        "//third_party/py/numpy",
46    ],
47)
48
# Python test built from tf_broadcast_to_test.py; depends on the tf_jitrt
# Python binding, client_testlib and numpy. Tagged "no_oss", "no_pip" and
# "nomsan"; the "nomsan" tag is tracked by the TODO on its line.
49py_strict_test(
50    name = "tf_broadcast_to_test",
51    srcs = ["tf_broadcast_to_test.py"],
52    python_version = "PY3",
53    tags = [
54        "no_oss",
55        "no_pip",  # TODO(b/201803253): TFRT pybindings not in OSS.
56        "nomsan",  # TODO(b/210849019)
57    ],
58    deps = [
59        "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt",
60        "//tensorflow/python:client_testlib",
61        "//third_party/py/numpy",
62    ],
63)
64
# Python test built from tf_cast_test.py; depends on the tf_jitrt Python
# binding, client_testlib and numpy. Tagged "no_oss" and "no_pip" (see TODO).
65py_strict_test(
66    name = "tf_cast_test",
67    srcs = ["tf_cast_test.py"],
68    python_version = "PY3",
69    tags = [
70        "no_oss",
71        "no_pip",
72    ],  # TODO(b/201803253): TFRT pybindings not in OSS.
73    deps = [
74        "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt",
75        "//tensorflow/python:client_testlib",
76        "//third_party/py/numpy",
77    ],
78)
79
# Python test built from tf_const_test.py; depends on the tf_jitrt Python
# binding, client_testlib and numpy. Tagged "no_oss" and "no_pip" (see TODO).
80py_strict_test(
81    name = "tf_const_test",
82    srcs = ["tf_const_test.py"],
83    python_version = "PY3",
84    tags = [
85        "no_oss",
86        "no_pip",
87    ],  # TODO(b/201803253): TFRT pybindings not in OSS.
88    deps = [
89        "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt",
90        "//tensorflow/python:client_testlib",
91        "//third_party/py/numpy",
92    ],
93)
94
# Python test built from tf_controlflow_test.py; depends on the tf_jitrt
# Python binding, client_testlib and numpy. Tagged "no_oss" and "no_pip"
# (see TODO).
95py_strict_test(
96    name = "tf_controlflow_test",
97    srcs = ["tf_controlflow_test.py"],
98    python_version = "PY3",
99    tags = [
100        "no_oss",
101        "no_pip",
102    ],  # TODO(b/201803253): TFRT pybindings not in OSS.
103    deps = [
104        "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt",
105        "//tensorflow/python:client_testlib",
106        "//third_party/py/numpy",
107    ],
108)
109
# Python test built from tf_function_test.py; depends on the tf_jitrt Python
# binding, client_testlib and numpy.
# NOTE(review): most sibling tests in this file are tagged both "no_oss" and
# "no_pip"; this one carries only "no_pip" although the TODO text is the
# same - confirm that omitting "no_oss" is intentional.
110py_strict_test(
111    name = "tf_function_test",
112    srcs = ["tf_function_test.py"],
113    python_version = "PY3",
114    tags = ["no_pip"],  # TODO(b/201803253): TFRT pybindings not in OSS.
115    deps = [
116        "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt",
117        "//tensorflow/python:client_testlib",
118        "//third_party/py/numpy",
119    ],
120)
121
# Python test built from tf_log1p_test.py; depends on the tf_jitrt Python
# binding, client_testlib and numpy. Tagged "no_oss" and "no_pip" (see TODO).
122py_strict_test(
123    name = "tf_log1p_test",
124    srcs = ["tf_log1p_test.py"],
125    python_version = "PY3",
126    tags = [
127        "no_oss",
128        "no_pip",
129    ],  # TODO(b/201803253): TFRT pybindings not in OSS.
130    deps = [
131        "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt",
132        "//tensorflow/python:client_testlib",
133        "//third_party/py/numpy",
134    ],
135)
136
# Python test built from tf_logical_ops_test.py; depends on the tf_jitrt
# Python binding, client_testlib and numpy.
# NOTE(review): most sibling tests in this file are tagged both "no_oss" and
# "no_pip"; this one carries only "no_pip" although the TODO text is the
# same - confirm that omitting "no_oss" is intentional.
137py_strict_test(
138    name = "tf_logical_ops_test",
139    srcs = ["tf_logical_ops_test.py"],
140    python_version = "PY3",
141    tags = ["no_pip"],  # TODO(b/201803253): TFRT pybindings not in OSS.
142    deps = [
143        "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt",
144        "//tensorflow/python:client_testlib",
145        "//third_party/py/numpy",
146    ],
147)
148
# Python test built from tf_math_ops_test.py. Besides the tf_jitrt Python
# binding, client_testlib and numpy that the other tests here use, it also
# depends on //tensorflow:tensorflow_py and on absl's flags and
# testing:parameterized libraries.
# NOTE(review): most sibling tests in this file are tagged both "no_oss" and
# "no_pip"; this one carries only "no_pip" although the TODO text is the
# same - confirm that omitting "no_oss" is intentional.
149py_strict_test(
150    name = "tf_math_ops_test",
151    srcs = ["tf_math_ops_test.py"],
152    python_version = "PY3",
153    tags = ["no_pip"],  # TODO(b/201803253): TFRT pybindings not in OSS.
154    deps = [
155        "//tensorflow:tensorflow_py",
156        "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt",
157        "//tensorflow/python:client_testlib",
158        "//third_party/py/numpy",
159        "@absl_py//absl/flags",
160        "@absl_py//absl/testing:parameterized",
161    ],
162)
163
# Python test built from tf_matmul_test.py; depends on the tf_jitrt Python
# binding, client_testlib and numpy. Tagged "no_oss" and "no_pip" (see TODO).
164py_strict_test(
165    name = "tf_matmul_test",
166    srcs = ["tf_matmul_test.py"],
167    python_version = "PY3",
168    tags = [
169        "no_oss",
170        "no_pip",
171    ],  # TODO(b/201803253): TFRT pybindings not in OSS.
172    deps = [
173        "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt",
174        "//tensorflow/python:client_testlib",
175        "//third_party/py/numpy",
176    ],
177)
178
# Python test built from tf_mean_test.py; depends on the tf_jitrt Python
# binding, client_testlib and numpy. Tagged "no_oss" and "no_pip" (see TODO).
179py_strict_test(
180    name = "tf_mean_test",
181    srcs = ["tf_mean_test.py"],
182    python_version = "PY3",
183    tags = [
184        "no_oss",
185        "no_pip",
186    ],  # TODO(b/201803253): TFRT pybindings not in OSS.
187    deps = [
188        "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt",
189        "//tensorflow/python:client_testlib",
190        "//third_party/py/numpy",
191    ],
192)
193
# Python test built from tf_metadata_ops_test.py; depends on the tf_jitrt
# Python binding, client_testlib and numpy. Tagged "no_oss" and "no_pip"
# (see TODO).
194py_strict_test(
195    name = "tf_metadata_ops_test",
196    srcs = ["tf_metadata_ops_test.py"],
197    python_version = "PY3",
198    tags = [
199        "no_oss",
200        "no_pip",
201    ],  # TODO(b/201803253): TFRT pybindings not in OSS.
202    deps = [
203        "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt",
204        "//tensorflow/python:client_testlib",
205        "//third_party/py/numpy",
206    ],
207)
208
# Python test built from tf_pack_test.py; depends on the tf_jitrt Python
# binding, client_testlib and numpy. Tagged "no_oss" and "no_pip" (see TODO).
209py_strict_test(
210    name = "tf_pack_test",
211    srcs = ["tf_pack_test.py"],
212    python_version = "PY3",
213    tags = [
214        "no_oss",
215        "no_pip",
216    ],  # TODO(b/201803253): TFRT pybindings not in OSS.
217    deps = [
218        "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt",
219        "//tensorflow/python:client_testlib",
220        "//third_party/py/numpy",
221    ],
222)
223
# Python test built from tf_reshape_test.py; depends on the tf_jitrt Python
# binding, client_testlib and numpy. Tagged "no_oss" and "no_pip" (see TODO).
224py_strict_test(
225    name = "tf_reshape_test",
226    srcs = ["tf_reshape_test.py"],
227    python_version = "PY3",
228    tags = [
229        "no_oss",
230        "no_pip",
231    ],  # TODO(b/201803253): TFRT pybindings not in OSS.
232    deps = [
233        "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt",
234        "//tensorflow/python:client_testlib",
235        "//third_party/py/numpy",
236    ],
237)
238
# Python test built from tf_select_test.py; depends on the tf_jitrt Python
# binding, client_testlib and numpy.
# NOTE(review): most sibling tests in this file are tagged both "no_oss" and
# "no_pip"; this one carries only "no_pip" although the TODO text is the
# same - confirm that omitting "no_oss" is intentional.
239py_strict_test(
240    name = "tf_select_test",
241    srcs = ["tf_select_test.py"],
242    python_version = "PY3",
243    tags = ["no_pip"],  # TODO(b/201803253): TFRT pybindings not in OSS.
244    deps = [
245        "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt",
246        "//tensorflow/python:client_testlib",
247        "//third_party/py/numpy",
248    ],
249)
250
# Python test built from tf_strided_slice_test.py; depends on the tf_jitrt
# Python binding, client_testlib and numpy. Tagged "no_oss" and "no_pip"
# (see TODO).
251py_strict_test(
252    name = "tf_strided_slice_test",
253    srcs = ["tf_strided_slice_test.py"],
254    python_version = "PY3",
255    tags = [
256        "no_oss",
257        "no_pip",
258    ],  # TODO(b/201803253): TFRT pybindings not in OSS.
259    deps = [
260        "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt",
261        "//tensorflow/python:client_testlib",
262        "//third_party/py/numpy",
263    ],
264)
265
# Python test built from tf_transpose_test.py; depends on the tf_jitrt Python
# binding, client_testlib and numpy. Tagged "no_oss" and "no_pip" (see TODO).
266py_strict_test(
267    name = "tf_transpose_test",
268    srcs = ["tf_transpose_test.py"],
269    python_version = "PY3",
270    tags = [
271        "no_oss",
272        "no_pip",
273    ],  # TODO(b/201803253): TFRT pybindings not in OSS.
274    deps = [
275        "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt",
276        "//tensorflow/python:client_testlib",
277        "//third_party/py/numpy",
278    ],
279)
280
# Python test built from tf_reduction_test.py; depends on the tf_jitrt Python
# binding, client_testlib and numpy. Tagged "no_oss" and "no_pip" (see TODO).
281py_strict_test(
282    name = "tf_reduction_test",
283    srcs = ["tf_reduction_test.py"],
284    python_version = "PY3",
285    tags = [
286        "no_oss",
287        "no_pip",
288    ],  # TODO(b/201803253): TFRT pybindings not in OSS.
289    deps = [
290        "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt",
291        "//tensorflow/python:client_testlib",
292        "//third_party/py/numpy",
293    ],
294)
295
# Python test built from tf_softmax_test.py; depends on the tf_jitrt Python
# binding, client_testlib and numpy. Tagged "no_oss" and "no_pip" (see TODO).
296py_strict_test(
297    name = "tf_softmax_test",
298    srcs = ["tf_softmax_test.py"],
299    python_version = "PY3",
300    tags = [
301        "no_oss",
302        "no_pip",
303    ],  # TODO(b/201803253): TFRT pybindings not in OSS.
304    deps = [
305        "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt",
306        "//tensorflow/python:client_testlib",
307        "//third_party/py/numpy",
308    ],
309)
310
# TableGen source library wrapping python_test_attrs.td, with MLIR's
# OpBaseTdFiles as its only dependency. Consumed by
# :python_test_attrs_inc_gen in this file.
311td_library(
312    name = "python_test_attrs_td_files",
313    srcs = ["python_test_attrs.td"],
314    deps = [
315        "@llvm-project//mlir:OpBaseTdFiles",
316    ],
317)
318
# Runs mlir-tblgen over python_test_attrs.td and produces two generated
# files: python_test_attrs.h.inc (via -gen-dialect-decls) and
# python_test_attrs.cc.inc (via -gen-dialect-defs).
319gentbl_cc_library(
320    name = "python_test_attrs_inc_gen",
321    tbl_outs = [
322        (
323            ["-gen-dialect-decls"],
324            "python_test_attrs.h.inc",
325        ),
326        (
327            ["-gen-dialect-defs"],
328            "python_test_attrs.cc.inc",
329        ),
330    ],
331    tblgen = "@llvm-project//mlir:mlir-tblgen",
332    td_file = "python_test_attrs.td",
333    deps = [":python_test_attrs_td_files"],
334)
335
# C++ library built from python_test_attrs.cc / python_test_attrs.h. It
# depends on the TableGen-generated :python_test_attrs_inc_gen output and on
# MLIR's IR and Support libraries. No visibility attribute is set on this
# target.
336cc_library(
337    name = "python_test_attrs",
338    srcs = [
339        "python_test_attrs.cc",
340    ],
341    hdrs = [
342        "python_test_attrs.h",
343    ],
344    deps = [
345        ":python_test_attrs_inc_gen",
346        "@llvm-project//mlir:IR",
347        "@llvm-project//mlir:Support",
348    ],
349)
350
# Publicly visible C++ library built from python_test_attrs_registration.cc
# and its header; depends on :python_test_attrs and MLIR's IR library.
351cc_library(
352    name = "python_test_attrs_registration",
353    srcs = ["python_test_attrs_registration.cc"],
354    hdrs = ["python_test_attrs_registration.h"],
355    visibility = ["//visibility:public"],
356    deps = [
357        ":python_test_attrs",
358        "@llvm-project//mlir:IR",
359    ],
360)
361