• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1load(
2    "//tensorflow:tensorflow.bzl",
3    "tf_cc_binary",
4    "tf_cc_test",
5    "tf_py_test",
6)
7
8# buildifier: disable=same-origin-load
9load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud")
10
11# buildifier: disable=same-origin-load
12load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
13load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_libraries")
14load(
15    "@llvm-project//mlir:tblgen.bzl",
16    "gentbl_cc_library",
17    "td_library",
18)
19load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
20load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
21
# Package defaults: targets are visible to the `friends` package group unless
# they set their own `visibility`, and the package license is `notice`.
package(
    default_visibility = [":friends"],
    licenses = ["notice"],
)
28
# Packages granted access through this package's default visibility.
package_group(
    name = "friends",
    packages = [
        "//tensorflow/c/...",
        "//tensorflow/compiler/...",
        # The MLIR language server also needs to see these targets.
        "//learning/brain/mlir/mlir_lsp_server/...",
    ],
)
38
# TableGen source for the TFR dialect ops, bundled with the .td libraries it
# depends on so that tblgen rules can consume it.
td_library(
    name = "tfr_ops_td_files",
    srcs = ["ir/tfr_ops.td"],
    compatible_with = get_compatible_with_cloud(),
    deps = [
        "//tensorflow/compiler/mlir/lite/quantization/ir:QuantizationOpsTdFiles",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
        "@llvm-project//mlir:CallInterfacesTdFiles",
        "@llvm-project//mlir:ControlFlowInterfacesTdFiles",
        "@llvm-project//mlir:FunctionInterfacesTdFiles",
        "@llvm-project//mlir:OpBaseTdFiles",
        "@llvm-project//mlir:ShapeOpsTdFiles",
        "@llvm-project//mlir:SideEffectInterfacesTdFiles",
    ],
)
56
# Runs mlir-tblgen over ir/tfr_ops.td to emit the op declarations
# (ir/tfr_ops.h.inc) and op definitions (ir/tfr_ops.cc.inc) used by :tfr.
gentbl_cc_library(
    name = "tfr_ops_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (
            ["-gen-op-decls"],
            "ir/tfr_ops.h.inc",
        ),
        (
            ["-gen-op-defs"],
            "ir/tfr_ops.cc.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "ir/tfr_ops.td",
    deps = [":tfr_ops_td_files"],
)
76
# Turns the declarative rewrite patterns in passes/decompose_patterns.td
# (-gen-rewriters) into passes/generated_decompose.inc, which :passes lists
# among its sources.
gentbl_cc_library(
    name = "tfr_decompose_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (
            ["-gen-rewriters"],
            "passes/generated_decompose.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "passes/decompose_patterns.td",
    deps = [
        ":tfr_ops_td_files",
        "@llvm-project//mlir:ArithmeticOpsTdFiles",
        "@llvm-project//mlir:FuncTdFiles",
    ],
)
94
# The TFR dialect library: hand-written op/type code plus the tblgen output
# produced by :tfr_ops_inc_gen.
cc_library(
    name = "tfr",
    srcs = [
        "ir/tfr_ops.cc",
        "ir/tfr_ops.cc.inc",
    ],
    hdrs = [
        "ir/tfr_ops.h",
        "ir/tfr_ops.h.inc",
        "ir/tfr_types.h",
    ],
    deps = [
        ":tfr_ops_inc_gen",
        "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ArithmeticDialect",
        "@llvm-project//mlir:ControlFlowInterfaces",
        "@llvm-project//mlir:Dialect",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:InferTypeOpInterface",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:ShapeDialect",
        "@llvm-project//mlir:SideEffectInterfaces",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TransformUtils",
    ],
)
127
# Shared utilities (utils/utils.{h,cc}) layered on the TFR dialect.
cc_library(
    name = "utils",
    srcs = ["utils/utils.cc"],
    hdrs = ["utils/utils.h"],
    deps = [
        ":tfr",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Support",
    ],
)
143
# TFR transformation passes (canonicalize, decompose, raise_to_tf,
# rewrite_quantized_io). The tblgen'd decomposition patterns from
# :tfr_decompose_inc_gen are compiled in via passes/generated_decompose.inc.
cc_library(
    name = "passes",
    srcs = [
        "passes/canonicalize.cc",
        "passes/decompose.cc",
        "passes/generated_decompose.inc",
        "passes/raise_to_tf.cc",
        "passes/rewrite_quantized_io.cc",
    ],
    hdrs = ["passes/passes.h"],
    deps = [
        ":tfr",
        ":utils",
        "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AffineUtils",
        "@llvm-project//mlir:ArithmeticDialect",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:SCFToControlFlow",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TransformUtils",
    ],
    # Link every object file even if no symbol is referenced directly
    # (presumably so static pass registrations are kept — confirm).
    alwayslink = 1,
)
178
# Opt-style driver (passes/tfr_opt.cc, built on MlirOptLib) that links the TFR
# dialect and passes; it is bundled into :test_utilities for the lit tests.
tf_cc_binary(
    name = "tfr-opt",
    srcs = ["passes/tfr_opt.cc"],
    deps = [
        ":passes",
        ":tfr",
        "//tensorflow/compiler/mlir:init_mlir",
        "//tensorflow/compiler/mlir:passes",
        "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:ArithmeticDialect",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:MlirOptLib",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:ShapeDialect",
    ],
)
199
# Lit tests for the .mlir files matched by glob_lit_tests, run through
# run_lit.sh with the tools bundled in :test_utilities.
glob_lit_tests(
    data = [":test_utilities"],
    driver = "//tensorflow/compiler/mlir:run_lit.sh",
    test_file_exts = ["mlir"],
)
205
# Tools needed at runtime by the lit tests: the tfr-opt driver, FileCheck,
# `not`, and the lit runner script.
filegroup(
    name = "test_utilities",
    testonly = True,
    data = [
        # Same-package target: use the short label form.
        ":tfr-opt",
        "@llvm-project//llvm:FileCheck",
        "@llvm-project//llvm:not",
        "@llvm-project//mlir:run_lit.sh",
    ],
)
217
# Integration layer (integration/tfr_decompose_ctx.*) built on :passes and the
# TF dialect conversion/export utilities. It is shared by
# :graph_decompose_pass and :node_expansion_pass.
cc_library(
    name = "tfr_decompose_ctx",
    srcs = ["integration/tfr_decompose_ctx.cc"],
    hdrs = ["integration/tfr_decompose_ctx.h"],
    deps = [
        ":passes",
        ":tfr",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:convert_attr",
        "//tensorflow/compiler/mlir/tensorflow:convert_type",
        "//tensorflow/compiler/mlir/tensorflow:export_graphdef",
        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ArithmeticDialect",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Parser",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:ShapeDialect",
        "@llvm-project//mlir:Transforms",
    ],
)
247
# C++ unit test for :tfr_decompose_ctx.
tf_cc_test(
    name = "tfr_decompose_ctx_test",
    srcs = ["integration/tfr_decompose_ctx_test.cc"],
    deps = [
        ":tfr_decompose_ctx",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/core:framework",
        "//tensorflow/core:ops",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/types:span",
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:IR",
    ],
)
265
# Graph-level integration of :tfr_decompose_ctx, built on the MLIR graph
# optimization pass infrastructure.
cc_library(
    name = "graph_decompose_pass",
    srcs = ["integration/graph_decompose_pass.cc"],
    hdrs = ["integration/graph_decompose_pass.h"],
    deps = [
        ":tfr_decompose_ctx",
        "//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
        "//tensorflow/core:lib",
        "//tensorflow/core/common_runtime:device_set",
        "//tensorflow/stream_executor/lib",
        "@llvm-project//mlir:IR",
    ],
    # Force-link all objects even when no symbol is referenced directly.
    alwayslink = 1,
)
280
# Python test for the graph decomposition integration. It needs the
# decomposition library from resources/ as runtime data.
tf_py_test(
    name = "graph_decompose_test",
    size = "small",
    srcs = ["integration/graph_decompose_test.py"],
    data = ["//tensorflow/compiler/mlir/tfr/resources:decomposition_lib"],
    python_version = "PY3",
    srcs_version = "PY3",
    tags = [
        "no_pip",
        "no_windows",  # TODO(b/170752141)
        "nomac",  # TODO(b/170752141)
    ],
    deps = [
        "//tensorflow/compiler/mlir/tfr/resources:composite_ops",
        "//tensorflow/python/eager:def_function",
    ],
)
298
# Eager-runtime integration of :tfr_decompose_ctx, built on the eager op
# rewrite registry.
cc_library(
    name = "node_expansion_pass",
    srcs = ["integration/node_expansion_pass.cc"],
    hdrs = ["integration/node_expansion_pass.h"],
    deps = [
        ":tfr_decompose_ctx",
        "//tensorflow/core:lib",
        "//tensorflow/core/common_runtime/eager:core_no_xla",
        "//tensorflow/core/common_runtime/eager:eager_op_rewrite_registry",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/strings",
    ],
    # Force-link all objects even when no symbol is referenced directly.
    alwayslink = 1,
)
313
# Python test for the eager node-expansion integration. It needs the
# decomposition library from resources/ as runtime data.
tf_py_test(
    name = "node_expansion_test",
    size = "small",
    srcs = ["integration/node_expansion_test.py"],
    data = ["//tensorflow/compiler/mlir/tfr/resources:decomposition_lib"],
    python_version = "PY3",
    srcs_version = "PY3",
    tags = [
        "no_pip",
        "no_windows",  # TODO(b/170752141)
        "nomac",  # TODO(b/170752141)
    ],
    deps = ["//tensorflow/compiler/mlir/tfr/resources:composite_ops"],
)
330
# pybind11 extension (python/tfr_wrapper.cc) exposing TFR functionality to
# Python; :tfr_gen depends on it.
tf_python_pybind_extension(
    name = "tfr_wrapper",
    srcs = ["python/tfr_wrapper.cc"],
    deps = [
        # Same-package target: use the short label form.
        ":tfr",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/python:pybind11_lib",
        "//tensorflow/python:pybind11_status",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ArithmeticDialect",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Parser",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:ShapeDialect",
        "@pybind11",
    ],
)
349
# Python module python/composite.py (no dependencies declared).
py_library(
    name = "composite",
    srcs = ["python/composite.py"],
    srcs_version = "PY3",
)
355
# Python generator (python/tfr_gen.py) built on autograph's pyct utilities
# and the :tfr_wrapper extension.
py_library(
    name = "tfr_gen",
    srcs = ["python/tfr_gen.py"],
    srcs_version = "PY3",
    deps = [
        # Same-package target: use the short label form.
        ":tfr_wrapper",
        "//tensorflow:tensorflow_py",  # buildcleaner: keep
        "//tensorflow/python/autograph/converters",
        "//tensorflow/python/autograph/impl",
        "//tensorflow/python/autograph/pyct",
        "//tensorflow/python/autograph/pyct/static_analysis",
        "//tensorflow/python/framework",
        "//tensorflow/python/framework:dtypes",
        "//tensorflow/python/framework:op_def_registry",
        "//tensorflow/python/platform",
        "//tensorflow/python/util",
        "@gast_archive//:gast",
    ],
)
375
# Test for :tfr_gen, using the test ops from resources/ and the FileCheck
# Python wrapper.
tf_py_test(
    name = "tfr_gen_test",
    size = "small",
    srcs = ["python/tfr_gen_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    tags = ["no_pip"],
    deps = [
        ":composite",
        ":tfr_gen",
        "//tensorflow/compiler/mlir/python/mlir_wrapper:filecheck_wrapper",
        "//tensorflow/compiler/mlir/tfr/resources:test_ops",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:math_ops",
    ],
)
392
# Python module python/op_reg_gen.py; depends only on the TF Python API.
py_library(
    name = "op_reg_gen",
    srcs = ["python/op_reg_gen.py"],
    srcs_version = "PY3",
    deps = ["//tensorflow:tensorflow_py"],
)
401
# Test for :op_reg_gen, checked with the FileCheck Python wrapper.
tf_py_test(
    name = "op_reg_gen_test",
    size = "small",
    srcs = ["python/op_reg_gen_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    tags = ["no_pip"],
    deps = [
        ":composite",
        ":op_reg_gen",
        "//tensorflow/compiler/mlir/python/mlir_wrapper:filecheck_wrapper",
    ],
)
415
# Python module python/test_utils.py; depends only on the TF Python API.
py_library(
    name = "test_utils",
    srcs = ["python/test_utils.py"],
    srcs_version = "PY3",
    deps = ["//tensorflow:tensorflow_py"],
)
424
# Op libraries generated from define_op_template.py by the gen_op_libraries
# macro (see tfr/build_defs.bzl for which targets it creates).
gen_op_libraries(
    name = "one_op",
    src = "define_op_template.py",
)
429