load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
load("//tensorflow/core/platform:rules_cc.bzl", "cc_binary", "cc_library")
load("//tensorflow:tensorflow.bzl", "filegroup", "get_compatible_with_cloud")

package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],
)

# ---------------------------------------------------------------------------
# Shared base (used by both the CHLO and StableHLO dialect libraries below).
# ---------------------------------------------------------------------------

# C++ support code from dialect/Base.{h,cpp}; :chlo_ops and :stablehlo_ops both
# depend on it.
cc_library(
    name = "base",
    srcs = [
        "dialect/Base.cpp",
    ],
    hdrs = [
        "dialect/Base.h",
    ],
    compatible_with = get_compatible_with_cloud(),
    includes = ["."],
    deps = [
        ":base_attr_interfaces_inc_gen",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:InferTypeOpInterface",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:ShapeDialect",
        "@llvm-project//mlir:Support",
    ],
)

# Attribute-interface declarations/definitions generated from dialect/Base.td.
gentbl_cc_library(
    name = "base_attr_interfaces_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (
            ["-gen-attr-interface-decls"],
            "dialect/BaseAttrInterfaces.h.inc",
        ),
        (
            ["-gen-attr-interface-defs"],
            "dialect/BaseAttrInterfaces.cpp.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "dialect/Base.td",
    # Base.td is owned by :base_td_files. Depending on :stablehlo_td_files here
    # would make the shared base layer depend on the StableHLO op definitions.
    deps = [":base_td_files"],
)

# TableGen sources for the shared base definitions.
td_library(
    name = "base_td_files",
    srcs = [
        "dialect/Base.td",
    ],
    compatible_with = get_compatible_with_cloud(),
    includes = ["."],
    deps = [
        "@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
        "@llvm-project//mlir:OpBaseTdFiles",
        "@llvm-project//mlir:QuantizationOpsTdFiles",
    ],
)

# C++ helpers from dialect/BroadcastUtils.{h,cpp}; used by :chlo_ops.
cc_library(
    name = "broadcast_utils",
    srcs = [
        "dialect/BroadcastUtils.cpp",
    ],
    hdrs = [
        "dialect/BroadcastUtils.h",
    ],
    compatible_with = get_compatible_with_cloud(),
    includes = ["."],
    deps = [
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:ShapeDialect",
    ],
)

# ---------------------------------------------------------------------------
# CHLO dialect.
# ---------------------------------------------------------------------------

# Attribute declarations/definitions generated from dialect/ChloOps.td.
gentbl_cc_library(
    name = "chlo_attrs_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (
            ["-gen-attrdef-decls"],
            "dialect/ChloAttrs.h.inc",
        ),
        (
            ["-gen-attrdef-defs"],
            "dialect/ChloAttrs.cpp.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "dialect/ChloOps.td",
    deps = [
        ":chlo_td_files",
    ],
)

# Enum declarations/definitions generated from dialect/ChloOps.td.
gentbl_cc_library(
    name = "chlo_enums_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (
            ["-gen-enum-decls"],
            "dialect/ChloEnums.h.inc",
        ),
        (
            ["-gen-enum-defs"],
            "dialect/ChloEnums.cpp.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "dialect/ChloOps.td",
    deps = [
        ":chlo_td_files",
    ],
)

# Op declarations/definitions generated from dialect/ChloOps.td.
gentbl_cc_library(
    name = "chlo_ops_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (
            ["-gen-op-decls"],
            "dialect/ChloOps.h.inc",
        ),
        (
            ["-gen-op-defs"],
            "dialect/ChloOps.cpp.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "dialect/ChloOps.td",
    deps = [
        ":chlo_td_files",
    ],
)

# TableGen sources for the CHLO dialect.
td_library(
    name = "chlo_td_files",
    srcs = [
        "dialect/ChloEnums.td",
        "dialect/ChloOps.td",
    ],
    compatible_with = get_compatible_with_cloud(),
    includes = ["."],
    deps = [
        ":base_td_files",
        "@llvm-project//mlir:BuiltinDialectTdFiles",
        "@llvm-project//mlir:ControlFlowInterfacesTdFiles",
        "@llvm-project//mlir:OpBaseTdFiles",
    ],
)

# CHLO dialect C++ library: hand-written dialect/ChloOps.{h,cpp} plus the
# generated attribute, enum and op code above.
cc_library(
    name = "chlo_ops",
    srcs = [
        "dialect/ChloOps.cpp",
    ],
    hdrs = [
        "dialect/ChloOps.h",
    ],
    compatible_with = get_compatible_with_cloud(),
    includes = ["."],
    deps = [
        ":base",
        ":broadcast_utils",
        ":chlo_attrs_inc_gen",
        ":chlo_enums_inc_gen",
        ":chlo_ops_inc_gen",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ComplexDialect",
        "@llvm-project//mlir:ControlFlowInterfaces",
        "@llvm-project//mlir:Dialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:InferTypeOpInterface",
        "@llvm-project//mlir:QuantOps",
    ],
)

# ---------------------------------------------------------------------------
# Dialect registration.
# ---------------------------------------------------------------------------

# dialect/Register.{h,cpp}, built on top of :chlo_ops and :stablehlo_ops.
cc_library(
    name = "register",
    srcs = [
        "dialect/Register.cpp",
    ],
    hdrs = [
        "dialect/Register.h",
    ],
    compatible_with = get_compatible_with_cloud(),
    # Same include root as the other libraries in this package, so the
    # "dialect/..." include paths resolve for this target's own headers too.
    includes = ["."],
    deps = [
        ":chlo_ops",
        ":stablehlo_ops",
        "@llvm-project//mlir:IR",
    ],
)

# ---------------------------------------------------------------------------
# StableHLO dialect.
# ---------------------------------------------------------------------------

# Attribute declarations/definitions generated from dialect/StablehloOps.td.
gentbl_cc_library(
    name = "stablehlo_attrs_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (
            ["-gen-attrdef-decls"],
            "dialect/StablehloAttrs.h.inc",
        ),
        (
            ["-gen-attrdef-defs"],
            "dialect/StablehloAttrs.cpp.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "dialect/StablehloOps.td",
    deps = [
        ":stablehlo_td_files",
    ],
)

# Enum declarations/definitions generated from dialect/StablehloOps.td.
gentbl_cc_library(
    name = "stablehlo_enums_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (
            ["-gen-enum-decls"],
            "dialect/StablehloEnums.h.inc",
        ),
        (
            ["-gen-enum-defs"],
            "dialect/StablehloEnums.cpp.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "dialect/StablehloOps.td",
    deps = [
        ":stablehlo_td_files",
    ],
)

# Op declarations/definitions generated from dialect/StablehloOps.td.
gentbl_cc_library(
    name = "stablehlo_ops_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (
            ["-gen-op-decls"],
            "dialect/StablehloOps.h.inc",
        ),
        (
            ["-gen-op-defs"],
            "dialect/StablehloOps.cpp.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "dialect/StablehloOps.td",
    deps = [
        ":stablehlo_td_files",
    ],
)

# TableGen sources for the StableHLO dialect. dialect/Base.td is not listed in
# srcs: it is owned by :base_td_files and reaches this library through that dep.
td_library(
    name = "stablehlo_td_files",
    srcs = [
        "dialect/StablehloAttrs.td",
        "dialect/StablehloEnums.td",
        "dialect/StablehloOps.td",
    ],
    compatible_with = get_compatible_with_cloud(),
    includes = ["."],
    deps = [
        ":base_td_files",
        "@llvm-project//mlir:BuiltinDialectTdFiles",
        "@llvm-project//mlir:OpBaseTdFiles",
        "@llvm-project//mlir:ShapeOpsTdFiles",
    ],
)

# StableHLO dialect C++ library: hand-written dialect/StablehloOps.{h,cpp} plus
# the generated attribute, enum and op code above.
cc_library(
    name = "stablehlo_ops",
    srcs = [
        "dialect/StablehloOps.cpp",
    ],
    hdrs = [
        "dialect/StablehloOps.h",
    ],
    compatible_with = get_compatible_with_cloud(),
    includes = ["."],
    deps = [
        ":base",
        ":stablehlo_attrs_inc_gen",
        ":stablehlo_enums_inc_gen",
        ":stablehlo_ops_inc_gen",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ArithmeticDialect",
        "@llvm-project//mlir:ComplexDialect",
        "@llvm-project//mlir:Dialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:InferTypeOpInterface",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:ShapeDialect",
        "@llvm-project//mlir:SparseTensorDialect",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TensorDialect",
    ],
)

# ---------------------------------------------------------------------------
# Tools and tests.
# ---------------------------------------------------------------------------

# Command-line driver (tools/StablehloOptMain.cpp) linking the dialect
# registration library and the test-pass library.
cc_binary(
    name = "stablehlo-opt",
    srcs = ["tools/StablehloOptMain.cpp"],
    compatible_with = get_compatible_with_cloud(),
    deps = [
        ":register",
        ":test_utils",
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:MlirOptLib",
    ],
)

# Lit tests for the `.mlir` test files in this package, run with the tools
# bundled in :test_data.
glob_lit_tests(
    data = [":test_data"],
    driver = "@llvm-project//mlir:run_lit.sh",
    test_file_exts = ["mlir"],
)

# Runtime data for the lit tests: the stablehlo-opt binary defined in this
# package (referenced by a package-relative label so the tests always exercise
# this package's binary) and FileCheck.
filegroup(
    name = "test_data",
    testonly = True,
    data = [
        ":stablehlo-opt",
        "@llvm-project//llvm:FileCheck",
    ],
)

# Pass declarations (name "HloTest") generated from tests/TestUtils.td.
gentbl_cc_library(
    name = "test_utils_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (
            [
                "-gen-pass-decls",
                "-name=HloTest",
            ],
            "tests/TestUtils.h.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "tests/TestUtils.td",
    deps = [
        ":test_utils_td_files",
    ],
)

# TableGen sources for the test passes.
td_library(
    name = "test_utils_td_files",
    srcs = [
        "tests/TestUtils.td",
    ],
    compatible_with = get_compatible_with_cloud(),
    includes = ["."],
    deps = [
        "@llvm-project//mlir:PassBaseTdFiles",
    ],
)

# Test-pass implementation (tests/TestUtils.{h,cpp}) linked into stablehlo-opt.
cc_library(
    name = "test_utils",
    srcs = [
        "tests/TestUtils.cpp",
    ],
    hdrs = [
        "tests/TestUtils.h",
    ],
    compatible_with = get_compatible_with_cloud(),
    includes = ["."],
    deps = [
        ":test_utils_inc_gen",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:InferTypeOpInterface",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:ShapeDialect",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:Transforms",
    ],
)