• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
2load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
3load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
4load("//tensorflow:tensorflow.bzl", "filegroup", "get_compatible_with_cloud")
5
# Package-wide defaults: the package is declared under the "notice" license
# class, and every target below is publicly visible unless it sets its own
# `visibility`.
package(
    licenses = ["notice"],
    default_visibility = ["//visibility:public"],
)
10
# Base C++ library shared by the dialect libraries in this file (both
# :chlo_ops and :stablehlo_ops depend on it). Compiles dialect/Base.cpp and
# exports dialect/Base.h.
cc_library(
    name = "base",
    srcs = ["dialect/Base.cpp"],
    hdrs = ["dialect/Base.h"],
    compatible_with = get_compatible_with_cloud(),
    includes = ["."],
    deps = [
        # TableGen-generated dialect/BaseAttrInterfaces.{h,cpp}.inc.
        ":base_attr_interfaces_inc_gen",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:InferTypeOpInterface",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:ShapeDialect",
        "@llvm-project//mlir:Support",
    ],
)
31
# Runs mlir-tblgen over dialect/Base.td to produce the attribute-interface
# declarations (BaseAttrInterfaces.h.inc) and definitions
# (BaseAttrInterfaces.cpp.inc) consumed by :base.
gentbl_cc_library(
    name = "base_attr_interfaces_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (
            ["-gen-attr-interface-decls"],
            "dialect/BaseAttrInterfaces.h.inc",
        ),
        (
            ["-gen-attr-interface-defs"],
            "dialect/BaseAttrInterfaces.cpp.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "dialect/Base.td",
    # Depend on the td_library that owns Base.td rather than on
    # :stablehlo_td_files. The latter also pulls in the StableHLO
    # attrs/enums/ops .td files, which would make :base (and everything built
    # on it, including :chlo_ops) regenerate whenever a StableHLO op definition
    # changes.
    deps = [":base_td_files"],
)
49
# TableGen sources for the shared base definitions (dialect/Base.td) plus the
# upstream MLIR .td libraries listed in `deps`.
td_library(
    name = "base_td_files",
    srcs = ["dialect/Base.td"],
    compatible_with = get_compatible_with_cloud(),
    includes = ["."],
    deps = [
        "@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
        "@llvm-project//mlir:OpBaseTdFiles",
        "@llvm-project//mlir:QuantizationOpsTdFiles",
    ],
)
63
# C++ library built from dialect/BroadcastUtils.{h,cpp}; used by :chlo_ops.
cc_library(
    name = "broadcast_utils",
    srcs = ["dialect/BroadcastUtils.cpp"],
    hdrs = ["dialect/BroadcastUtils.h"],
    compatible_with = get_compatible_with_cloud(),
    includes = ["."],
    deps = [
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:ShapeDialect",
    ],
)
80
# Runs mlir-tblgen over dialect/ChloOps.td to emit the CHLO attribute
# definitions: ChloAttrs.h.inc (declarations) and ChloAttrs.cpp.inc
# (definitions).
gentbl_cc_library(
    name = "chlo_attrs_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (["-gen-attrdef-decls"], "dialect/ChloAttrs.h.inc"),
        (["-gen-attrdef-defs"], "dialect/ChloAttrs.cpp.inc"),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "dialect/ChloOps.td",
    deps = [":chlo_td_files"],
)
100
# Runs mlir-tblgen over dialect/ChloOps.td to emit the CHLO enums:
# ChloEnums.h.inc (declarations) and ChloEnums.cpp.inc (definitions).
gentbl_cc_library(
    name = "chlo_enums_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (["-gen-enum-decls"], "dialect/ChloEnums.h.inc"),
        (["-gen-enum-defs"], "dialect/ChloEnums.cpp.inc"),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "dialect/ChloOps.td",
    deps = [":chlo_td_files"],
)
120
# Runs mlir-tblgen over dialect/ChloOps.td to emit the CHLO operations:
# ChloOps.h.inc (declarations) and ChloOps.cpp.inc (definitions).
gentbl_cc_library(
    name = "chlo_ops_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (["-gen-op-decls"], "dialect/ChloOps.h.inc"),
        (["-gen-op-defs"], "dialect/ChloOps.cpp.inc"),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "dialect/ChloOps.td",
    deps = [":chlo_td_files"],
)
140
# TableGen sources for the CHLO dialect (enums and ops), layered on top of
# :base_td_files and the upstream MLIR .td libraries listed in `deps`.
td_library(
    name = "chlo_td_files",
    srcs = [
        "dialect/ChloEnums.td",
        "dialect/ChloOps.td",
    ],
    compatible_with = get_compatible_with_cloud(),
    includes = ["."],
    deps = [
        # Shared base definitions (dialect/Base.td).
        ":base_td_files",
        "@llvm-project//mlir:BuiltinDialectTdFiles",
        "@llvm-project//mlir:ControlFlowInterfacesTdFiles",
        "@llvm-project//mlir:OpBaseTdFiles",
    ],
)
156
# C++ library for the CHLO dialect: compiles dialect/ChloOps.cpp, exports
# dialect/ChloOps.h, and pulls in the three TableGen-generated include sets
# (attrs, enums, ops).
cc_library(
    name = "chlo_ops",
    srcs = ["dialect/ChloOps.cpp"],
    hdrs = ["dialect/ChloOps.h"],
    compatible_with = get_compatible_with_cloud(),
    includes = ["."],
    deps = [
        # In-package libraries.
        ":base",
        ":broadcast_utils",
        # TableGen-generated .inc files.
        ":chlo_attrs_inc_gen",
        ":chlo_enums_inc_gen",
        ":chlo_ops_inc_gen",
        # LLVM / MLIR.
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ComplexDialect",
        "@llvm-project//mlir:ControlFlowInterfaces",
        "@llvm-project//mlir:Dialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:InferTypeOpInterface",
        "@llvm-project//mlir:QuantOps",
    ],
)
182
# C++ library built from dialect/Register.{h,cpp}. It depends on both dialect
# libraries in this package (:chlo_ops and :stablehlo_ops).
cc_library(
    name = "register",
    srcs = ["dialect/Register.cpp"],
    hdrs = ["dialect/Register.h"],
    compatible_with = get_compatible_with_cloud(),
    deps = [
        ":chlo_ops",
        ":stablehlo_ops",
        "@llvm-project//mlir:IR",
    ],
)
198
# Runs mlir-tblgen over dialect/StablehloOps.td to emit the StableHLO
# attribute definitions: StablehloAttrs.h.inc (declarations) and
# StablehloAttrs.cpp.inc (definitions).
gentbl_cc_library(
    name = "stablehlo_attrs_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (["-gen-attrdef-decls"], "dialect/StablehloAttrs.h.inc"),
        (["-gen-attrdef-defs"], "dialect/StablehloAttrs.cpp.inc"),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "dialect/StablehloOps.td",
    deps = [":stablehlo_td_files"],
)
218
# Runs mlir-tblgen over dialect/StablehloOps.td to emit the StableHLO enums:
# StablehloEnums.h.inc (declarations) and StablehloEnums.cpp.inc
# (definitions).
gentbl_cc_library(
    name = "stablehlo_enums_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (["-gen-enum-decls"], "dialect/StablehloEnums.h.inc"),
        (["-gen-enum-defs"], "dialect/StablehloEnums.cpp.inc"),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "dialect/StablehloOps.td",
    deps = [":stablehlo_td_files"],
)
238
# Runs mlir-tblgen over dialect/StablehloOps.td to emit the StableHLO
# operations: StablehloOps.h.inc (declarations) and StablehloOps.cpp.inc
# (definitions).
gentbl_cc_library(
    name = "stablehlo_ops_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (["-gen-op-decls"], "dialect/StablehloOps.h.inc"),
        (["-gen-op-defs"], "dialect/StablehloOps.cpp.inc"),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "dialect/StablehloOps.td",
    deps = [":stablehlo_td_files"],
)
258
# TableGen sources for the StableHLO dialect (attrs, enums, ops), layered on
# top of :base_td_files and the upstream MLIR .td libraries listed in `deps`.
td_library(
    name = "stablehlo_td_files",
    srcs = [
        # NOTE(review): Base.td is also the sole source of :base_td_files,
        # which is already a dep below; this entry looks redundant — confirm
        # before removing.
        "dialect/Base.td",
        "dialect/StablehloAttrs.td",
        "dialect/StablehloEnums.td",
        "dialect/StablehloOps.td",
    ],
    compatible_with = get_compatible_with_cloud(),
    includes = ["."],
    deps = [
        ":base_td_files",
        "@llvm-project//mlir:BuiltinDialectTdFiles",
        "@llvm-project//mlir:OpBaseTdFiles",
        "@llvm-project//mlir:ShapeOpsTdFiles",
    ],
)
276
# C++ library for the StableHLO dialect: compiles dialect/StablehloOps.cpp,
# exports dialect/StablehloOps.h, and pulls in the three TableGen-generated
# include sets (attrs, enums, ops).
cc_library(
    name = "stablehlo_ops",
    srcs = ["dialect/StablehloOps.cpp"],
    hdrs = ["dialect/StablehloOps.h"],
    compatible_with = get_compatible_with_cloud(),
    includes = ["."],
    deps = [
        # In-package library.
        ":base",
        # TableGen-generated .inc files.
        ":stablehlo_attrs_inc_gen",
        ":stablehlo_enums_inc_gen",
        ":stablehlo_ops_inc_gen",
        # LLVM / MLIR.
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ArithmeticDialect",
        "@llvm-project//mlir:ComplexDialect",
        "@llvm-project//mlir:Dialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:InferTypeOpInterface",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:ShapeDialect",
        "@llvm-project//mlir:SparseTensorDialect",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TensorDialect",
    ],
)
305
# Command-line binary built from tools/StablehloOptMain.cpp. Links the dialect
# registration library (:register), the test utilities (:test_utils), and the
# upstream MLIR opt driver and pass/dialect bundles. It is used as a tool by
# the lit tests via :test_data.
cc_binary(
    name = "stablehlo-opt",
    srcs = [
        "tools/StablehloOptMain.cpp",
    ],
    compatible_with = get_compatible_with_cloud(),
    deps = [
        # In-package libraries.
        ":register",
        ":test_utils",
        # Upstream MLIR.
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:MlirOptLib",
    ],
)
317
# Lit test declarations for this package, restricted to files with the "mlir"
# extension. The tests receive :test_data and are driven by MLIR's run_lit.sh.
glob_lit_tests(
    data = [
        ":test_data",
    ],
    driver = "@llvm-project//mlir:run_lit.sh",
    test_file_exts = [
        "mlir",
    ],
)
323
# Runtime tools needed by the lit tests declared via glob_lit_tests: the
# stablehlo-opt binary built in this package and LLVM's FileCheck.
filegroup(
    name = "test_data",
    testonly = True,
    data = [
        # Package-relative label, matching every other in-package reference in
        # this file. It names the same target as the previous absolute
        # //tensorflow/compiler/xla/mlir_hlo/stablehlo:stablehlo-opt label but
        # keeps working if the package is moved or vendored elsewhere.
        ":stablehlo-opt",
        "@llvm-project//llvm:FileCheck",
    ],
)
332
# Runs mlir-tblgen over tests/TestUtils.td to emit pass declarations
# (tests/TestUtils.h.inc), generated with `-name=HloTest`.
gentbl_cc_library(
    name = "test_utils_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (
            ["-gen-pass-decls", "-name=HloTest"],
            "tests/TestUtils.h.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "tests/TestUtils.td",
    deps = [":test_utils_td_files"],
)
351
# TableGen source for the test utility passes (tests/TestUtils.td), together
# with MLIR's pass base .td files.
td_library(
    name = "test_utils_td_files",
    srcs = ["tests/TestUtils.td"],
    compatible_with = get_compatible_with_cloud(),
    includes = ["."],
    deps = ["@llvm-project//mlir:PassBaseTdFiles"],
)
363
# C++ library built from tests/TestUtils.{h,cpp}, using the generated pass
# declarations from :test_utils_inc_gen. Linked into the stablehlo-opt binary.
cc_library(
    name = "test_utils",
    srcs = ["tests/TestUtils.cpp"],
    hdrs = ["tests/TestUtils.h"],
    compatible_with = get_compatible_with_cloud(),
    includes = ["."],
    deps = [
        # TableGen-generated tests/TestUtils.h.inc.
        ":test_utils_inc_gen",
        # LLVM / MLIR.
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:InferTypeOpInterface",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:ShapeDialect",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:Transforms",
    ],
)
386