load(
    "//tensorflow:tensorflow.bzl",
    "tf_cc_binary",
    "tf_cc_test",
    "tf_py_test",
)

# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud")

# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_libraries")
load(
    "@llvm-project//mlir:tblgen.bzl",
    "gentbl_cc_library",
    "td_library",
)
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")

package(
    default_visibility = [
        ":friends",
    ],
    licenses = ["notice"],
)

package_group(
    name = "friends",
    packages = [
        "//tensorflow/c/...",
        "//tensorflow/compiler/...",
        # Allow visibility from the mlir language server.
        "//learning/brain/mlir/mlir_lsp_server/...",
    ],
)

# TableGen sources for the TFR dialect ops (ir/tfr_ops.td) and the .td
# libraries it depends on.
td_library(
    name = "tfr_ops_td_files",
    srcs = [
        "ir/tfr_ops.td",
    ],
    compatible_with = get_compatible_with_cloud(),
    deps = [
        "//tensorflow/compiler/mlir/lite/quantization/ir:QuantizationOpsTdFiles",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
        "@llvm-project//mlir:CallInterfacesTdFiles",
        "@llvm-project//mlir:ControlFlowInterfacesTdFiles",
        "@llvm-project//mlir:FunctionInterfacesTdFiles",
        "@llvm-project//mlir:OpBaseTdFiles",
        "@llvm-project//mlir:ShapeOpsTdFiles",
        "@llvm-project//mlir:SideEffectInterfacesTdFiles",
    ],
)

# Generates the op declarations (ir/tfr_ops.h.inc) and definitions
# (ir/tfr_ops.cc.inc) from ir/tfr_ops.td; both are consumed by :tfr.
gentbl_cc_library(
    name = "tfr_ops_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (
            ["-gen-op-decls"],
            "ir/tfr_ops.h.inc",
        ),
        (
            ["-gen-op-defs"],
            "ir/tfr_ops.cc.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "ir/tfr_ops.td",
    deps = [
        ":tfr_ops_td_files",
    ],
)

# Generates rewriters from passes/decompose_patterns.td into
# passes/generated_decompose.inc, which :passes lists as a source.
gentbl_cc_library(
    name = "tfr_decompose_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (
            ["-gen-rewriters"],
            "passes/generated_decompose.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "passes/decompose_patterns.td",
    deps = [
        ":tfr_ops_td_files",
        "@llvm-project//mlir:ArithmeticOpsTdFiles",
        "@llvm-project//mlir:FuncTdFiles",
    ],
)

# The TFR dialect: hand-written op/type code plus the TableGen output above.
cc_library(
    name = "tfr",
    srcs = [
        "ir/tfr_ops.cc",
        "ir/tfr_ops.cc.inc",
    ],
    hdrs = [
        "ir/tfr_ops.h",
        "ir/tfr_ops.h.inc",
        "ir/tfr_types.h",
    ],
    deps = [
        ":tfr_ops_inc_gen",
        "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ArithmeticDialect",
        "@llvm-project//mlir:ControlFlowInterfaces",
        "@llvm-project//mlir:Dialect",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:InferTypeOpInterface",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:ShapeDialect",
        "@llvm-project//mlir:SideEffectInterfaces",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TransformUtils",
    ],
)

cc_library(
    name = "utils",
    srcs = [
        "utils/utils.cc",
    ],
    hdrs = [
        "utils/utils.h",
    ],
    deps = [
        ":tfr",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Support",
    ],
)

# TFR transformation passes (canonicalize, decompose, raise_to_tf,
# rewrite_quantized_io). passes/generated_decompose.inc is produced by
# :tfr_decompose_inc_gen.
cc_library(
    name = "passes",
    srcs = [
        "passes/canonicalize.cc",
        "passes/decompose.cc",
        "passes/generated_decompose.inc",
        "passes/raise_to_tf.cc",
        "passes/rewrite_quantized_io.cc",
    ],
    hdrs = [
        "passes/passes.h",
    ],
    deps = [
        ":tfr",
        ":utils",
        "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AffineUtils",
        "@llvm-project//mlir:ArithmeticDialect",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:SCFToControlFlow",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TransformUtils",
    ],
    # Link every object even if unreferenced (presumably to keep static pass
    # registrations alive -- confirm against passes/*.cc).
    alwayslink = 1,
)

# Command-line driver built on MlirOptLib; it is bundled into the lit test
# utilities below.
tf_cc_binary(
    name = "tfr-opt",
    srcs = ["passes/tfr_opt.cc"],
    deps = [
        ":passes",
        ":tfr",
        "//tensorflow/compiler/mlir:init_mlir",
        "//tensorflow/compiler/mlir:passes",
        "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:ArithmeticDialect",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:MlirOptLib",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:ShapeDialect",
    ],
)

# Lit tests for the .mlir test files, driven by run_lit.sh with the tools
# bundled in :test_utilities.
glob_lit_tests(
    data = [":test_utilities"],
    driver = "//tensorflow/compiler/mlir:run_lit.sh",
    test_file_exts = ["mlir"],
)

# Bundle together all of the test utilities that are used by tests.
filegroup(
    name = "test_utilities",
    testonly = True,
    data = [
        ":tfr-opt",
        "@llvm-project//llvm:FileCheck",
        "@llvm-project//llvm:not",
        "@llvm-project//mlir:run_lit.sh",
    ],
)

# C++ integration library layered on :passes and :tfr (the name suggests it
# drives TFR decomposition -- see integration/tfr_decompose_ctx.h).
cc_library(
    name = "tfr_decompose_ctx",
    srcs = ["integration/tfr_decompose_ctx.cc"],
    hdrs = ["integration/tfr_decompose_ctx.h"],
    deps = [
        ":passes",
        ":tfr",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:convert_attr",
        "//tensorflow/compiler/mlir/tensorflow:convert_type",
        "//tensorflow/compiler/mlir/tensorflow:export_graphdef",
        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ArithmeticDialect",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Parser",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:ShapeDialect",
        "@llvm-project//mlir:Transforms",
    ],
)

tf_cc_test(
    name = "tfr_decompose_ctx_test",
    srcs = ["integration/tfr_decompose_ctx_test.cc"],
    deps = [
        ":tfr_decompose_ctx",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/core:framework",
        "//tensorflow/core:ops",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/types:span",
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:IR",
    ],
)

cc_library(
    name = "graph_decompose_pass",
    srcs = ["integration/graph_decompose_pass.cc"],
    hdrs = ["integration/graph_decompose_pass.h"],
    deps = [
        ":tfr_decompose_ctx",
        "//tensorflow/compiler/mlir:mlir_graph_optimization_pass",
        "//tensorflow/core:lib",
        "//tensorflow/core/common_runtime:device_set",
        "//tensorflow/stream_executor/lib",
        "@llvm-project//mlir:IR",
    ],
    # Presumably keeps a static pass registration from being dropped by the
    # linker -- confirm against graph_decompose_pass.cc.
    alwayslink = 1,
)

tf_py_test(
    name = "graph_decompose_test",
    size = "small",
    srcs = ["integration/graph_decompose_test.py"],
    data = ["//tensorflow/compiler/mlir/tfr/resources:decomposition_lib"],
    python_version = "PY3",
    srcs_version = "PY3",
    tags = [
        "no_pip",
        "no_windows",  # TODO(b/170752141)
        "nomac",  # TODO(b/170752141)
    ],
    deps = [
        "//tensorflow/compiler/mlir/tfr/resources:composite_ops",
        "//tensorflow/python/eager:def_function",
    ],
)

cc_library(
    name = "node_expansion_pass",
    srcs = ["integration/node_expansion_pass.cc"],
    hdrs = ["integration/node_expansion_pass.h"],
    deps = [
        ":tfr_decompose_ctx",
        "//tensorflow/core:lib",
        "//tensorflow/core/common_runtime/eager:core_no_xla",
        "//tensorflow/core/common_runtime/eager:eager_op_rewrite_registry",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/strings",
    ],
    # Presumably keeps a static rewrite registration from being dropped by
    # the linker -- confirm against node_expansion_pass.cc.
    alwayslink = 1,
)

tf_py_test(
    name = "node_expansion_test",
    size = "small",
    srcs = ["integration/node_expansion_test.py"],
    data = ["//tensorflow/compiler/mlir/tfr/resources:decomposition_lib"],
    python_version = "PY3",
    srcs_version = "PY3",
    tags = [
        "no_pip",
        "no_windows",  # TODO(b/170752141)
        "nomac",  # TODO(b/170752141)
    ],
    deps = [
        "//tensorflow/compiler/mlir/tfr/resources:composite_ops",
    ],
)

# Python extension module exposing python/tfr_wrapper.cc via pybind11.
tf_python_pybind_extension(
    name = "tfr_wrapper",
    srcs = ["python/tfr_wrapper.cc"],
    deps = [
        ":tfr",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/python:pybind11_lib",
        "//tensorflow/python:pybind11_status",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ArithmeticDialect",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Parser",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:ShapeDialect",
        "@pybind11",
    ],
)

py_library(
    name = "composite",
    srcs = ["python/composite.py"],
    srcs_version = "PY3",
)

py_library(
    name = "tfr_gen",
    srcs = ["python/tfr_gen.py"],
    srcs_version = "PY3",
    deps = [
        ":tfr_wrapper",
        "//tensorflow:tensorflow_py",  # buildcleaner: keep
        "//tensorflow/python/autograph/converters",
        "//tensorflow/python/autograph/impl",
        "//tensorflow/python/autograph/pyct",
        "//tensorflow/python/autograph/pyct/static_analysis",
        "//tensorflow/python/framework",
        "//tensorflow/python/framework:dtypes",
        "//tensorflow/python/framework:op_def_registry",
        "//tensorflow/python/platform",
        "//tensorflow/python/util",
        "@gast_archive//:gast",
    ],
)

tf_py_test(
    name = "tfr_gen_test",
    size = "small",
    srcs = ["python/tfr_gen_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    tags = ["no_pip"],
    deps = [
        ":composite",
        ":tfr_gen",
        "//tensorflow/compiler/mlir/python/mlir_wrapper:filecheck_wrapper",
        "//tensorflow/compiler/mlir/tfr/resources:test_ops",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:math_ops",
    ],
)

py_library(
    name = "op_reg_gen",
    srcs = ["python/op_reg_gen.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow:tensorflow_py",
    ],
)

tf_py_test(
    name = "op_reg_gen_test",
    size = "small",
    srcs = ["python/op_reg_gen_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    tags = ["no_pip"],
    deps = [
        ":composite",
        ":op_reg_gen",
        "//tensorflow/compiler/mlir/python/mlir_wrapper:filecheck_wrapper",
    ],
)

py_library(
    name = "test_utils",
    srcs = ["python/test_utils.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow:tensorflow_py",
    ],
)

gen_op_libraries(
    name = "one_op",
    src = "define_op_template.py",
)