# Bazel BUILD rules for the MLIR <-> XLA HLO translation libraries, the
# TensorFlow-to-HLO legalization passes, and the xla-opt test driver.
1load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
2load("//third_party/mlir:tblgen.bzl", "gentbl")
3load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_native_cc_binary")
4load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud")
5
# Package-wide defaults: targets are visible to the ":friends" package group
# unless they override `visibility` themselves.
package(
    licenses = ["notice"],  # Apache 2.0
    default_visibility = [":friends"],
)
10
# Packages allowed to depend on this package's default-visibility targets:
# the explicit patterns below plus every package in the included
# //third_party/mlir:subpackages group.
package_group(
    name = "friends",
    packages = [
        "//babelfish/device/...",
        "//learning/brain/experimental/dtensor/...",
        "//learning/brain/experimental/mlir/...",
        "//learning/brain/google/xla/kernels/...",
        "//learning/brain/google/xla/mlir/...",
        "//learning/deepmind/partir/...",
        "//learning/pathways/data_parallel/tf2xla/...",
        "//platforms/xla/...",
        "//tensorflow/compiler/mlir/...",
        "//tensorflow/compiler/tf2xla/...",
        "//tensorflow/compiler/xla/...",
        "//third_party/iree/...",
        "//third_party/mlir_edge/...",
    ],
    includes = ["//third_party/mlir:subpackages"],
)
30
# Runs `mlir-tblgen -gen-rewriters` on transforms/legalize_tf_patterns.td to
# produce transforms/generated_legalize_tf.inc, which :xla_legalize_tf lists
# in its srcs.
gentbl(
    name = "xla_legalize_tf_inc_gen",
    td_file = "transforms/legalize_tf_patterns.td",
    td_srcs = [
        "//tensorflow/compiler/mlir/hlo:hlo_ops_td_files",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:StdOpsTdFiles",
        "@llvm-project//mlir:TensorOpsTdFiles",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
    ],
    td_relative_includes = ["../hlo/include"],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    tbl_outs = [("-gen-rewriters", "transforms/generated_legalize_tf.inc")],
    compatible_with = get_compatible_with_cloud(),
)
50
# Runs `mlir-tblgen -gen-pass-decls -name XLA` on transforms/xla_passes.td to
# produce transforms/xla_passes.h.inc (a dependency of :xla_passes).
gentbl(
    name = "xla_passes_inc_gen",
    td_file = "transforms/xla_passes.td",
    td_srcs = [
        "@llvm-project//mlir:PassBaseTdFiles",
        "//tensorflow/compiler/mlir/hlo:hlo_ops_td_files",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:StdOpsTdFiles",
        "@llvm-project//mlir:TensorOpsTdFiles",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
    ],
    td_relative_includes = ["../hlo/include"],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    tbl_outs = [("-gen-pass-decls -name XLA", "transforms/xla_passes.h.inc")],
    compatible_with = get_compatible_with_cloud(),
)
71
# Pass library built from legalize_tf_types.cc and prepare_for_export.cc,
# using the pass declarations generated by :xla_passes_inc_gen.
cc_library(
    name = "xla_passes",
    hdrs = ["transforms/passes.h"],
    srcs = [
        "transforms/legalize_tf_types.cc",
        "transforms/passes_detail.h",
        "transforms/prepare_for_export.cc",
    ],
    # alwayslink keeps every object file even if no symbol is referenced.
    # NOTE(review): presumably so static pass registrations are not dropped.
    alwayslink = 1,
    deps = [
        # Generated code and local/TF dependencies.
        ":xla_passes_inc_gen",
        "//tensorflow/compiler/mlir/hlo",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        # LLVM / MLIR.
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TransformUtils",
    ],
)
94
# TF -> HLO legalization passes (general, communication and control-flow
# sources), compiled together with the rewriters that
# :xla_legalize_tf_inc_gen generates.
# NOTE(review): transforms/passes.h is also exported by :xla_passes.
cc_library(
    name = "xla_legalize_tf",
    hdrs = ["transforms/passes.h"],
    srcs = [
        "transforms/generated_legalize_tf.inc",
        "transforms/legalize_tf.cc",
        "transforms/legalize_tf_communication.cc",
        "transforms/legalize_tf_control_flow.cc",
    ],
    alwayslink = 1,
    deps = [
        # Targets in this package.
        ":attribute_importer",
        ":type_to_shape",
        ":xla_legalize_tf_with_tf2xla",
        # TensorFlow / XLA.
        "//tensorflow/compiler/mlir/hlo",
        "//tensorflow/compiler/mlir/hlo:chlo_legalize_to_hlo",
        "//tensorflow/compiler/mlir/hlo:convert_op_folder",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:padding",
        "//tensorflow/compiler/xla/client:sharding_builder",
        "//tensorflow/compiler/xla/client/lib:conv_grad_size_util",
        "//tensorflow/core:framework",
        "//tensorflow/core/kernels:conv_grad_shape_utils",
        "//tensorflow/core/platform:bfloat16",
        # LLVM / MLIR.
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:Dialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Shape",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TensorDialect",
        "@llvm-project//mlir:Transforms",
    ],
)
137
# Legalization built on the tf2xla machinery (XlaOpRegistry, XlaContext, ...)
# and the MLIR-backed builder from :mlir_hlo_builder.
cc_library(
    name = "xla_legalize_tf_with_tf2xla",
    srcs = ["transforms/legalize_tf_with_tf2xla.cc"],
    alwayslink = 1,
    deps = [
        ":mlir_hlo_builder",
        # TensorFlow / XLA.
        "//tensorflow/compiler/mlir:op_or_arg_name_mapper",
        "//tensorflow/compiler/mlir/hlo",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:convert_tensor",
        "//tensorflow/compiler/mlir/tensorflow:convert_type",
        "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op",
        "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops",
        "//tensorflow/compiler/mlir/tensorflow:translate_utils",
        "//tensorflow/compiler/tf2xla:xla_compilation_device",
        "//tensorflow/compiler/tf2xla:xla_context",
        "//tensorflow/compiler/tf2xla:xla_expression",
        "//tensorflow/compiler/tf2xla:xla_helpers",
        "//tensorflow/compiler/tf2xla:xla_op_registry",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/core:core_cpu_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:session_options",
        "//tensorflow/stream_executor:timer",
        "//tensorflow/stream_executor/lib",
        # Abseil.
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        # LLVM / MLIR.
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TensorDialect",
        "@llvm-project//mlir:TransformUtils",
    ],
)
179
# Library for transforms/mhlo_to_lhlo_with_xla.{h,cc}. It draws on XLA buffer
# assignment and GPU service helpers (backend configs, IR emission utils).
cc_library(
    name = "mhlo_to_lhlo_with_xla",
    hdrs = ["transforms/mhlo_to_lhlo_with_xla.h"],
    srcs = ["transforms/mhlo_to_lhlo_with_xla.cc"],
    alwayslink = 1,
    deps = [
        # Targets in this package.
        ":attribute_importer",
        ":hlo_module_importer",
        ":hlo_utils",
        ":mlir_hlo_to_hlo",
        ":translate_cl_options",
        # TensorFlow / XLA.
        "//tensorflow/compiler/mlir/hlo",
        "//tensorflow/compiler/mlir/hlo:hlo_ops_base_enums",
        "//tensorflow/compiler/mlir/hlo:lhlo",
        "//tensorflow/compiler/mlir/hlo:lhlo_gpu",
        "//tensorflow/compiler/xla:debug_options_flags",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:window_util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/service:backend",
        "//tensorflow/compiler/xla/service:buffer_assignment",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_casting_utils",
        "//tensorflow/compiler/xla/service:hlo_parser",
        "//tensorflow/compiler/xla/service/gpu:backend_configs_cc",
        "//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
        "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
        # Abseil.
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/types:optional",
        # LLVM / MLIR.
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Translation",
    ],
)
218
# Builder in ir/mlir_hlo_builder.{h,cc}; depends on the XLA client
# XlaBuilder and shape inference.
cc_library(
    name = "mlir_hlo_builder",
    hdrs = ["ir/mlir_hlo_builder.h"],
    srcs = ["ir/mlir_hlo_builder.cc"],
    deps = [
        # Targets in this package.
        ":attribute_importer",
        ":hlo_module_importer",
        ":hlo_utils",
        ":type_to_shape",
        # TensorFlow / XLA.
        "//tensorflow/compiler/mlir/hlo",
        "//tensorflow/compiler/xla:comparison_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:shape_inference",
        "//tensorflow/core/platform:types",
        "//tensorflow/stream_executor/lib",
        # Third party.
        "@com_google_absl//absl/container:flat_hash_map",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
    ],
)
243
# Shared HLO helper library (hlo_utils.{h,cc}) used by the importer and
# exporter targets in this package.
cc_library(
    name = "hlo_utils",
    hdrs = ["hlo_utils.h"],
    srcs = ["hlo_utils.cc"],
    # Adds <package>/include to the include path of this target and its
    # dependents. NOTE(review): confirm an include/ directory exists here.
    includes = ["include"],
    alwayslink = 1,
    deps = [
        "//tensorflow/compiler/mlir/hlo",
        "//tensorflow/compiler/mlir/hlo:convert_op_folder",
        "//tensorflow/compiler/mlir/hlo:lhlo",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/core:lib",
        "@llvm-project//mlir:IR",
    ],
)
260
# Conversion helpers in type_to_shape.{h,cc}; covered by :type_to_shape_test.
cc_library(
    name = "type_to_shape",
    hdrs = ["type_to_shape.h"],
    srcs = ["type_to_shape.cc"],
    deps = [
        # TensorFlow / XLA.
        "//tensorflow/compiler/mlir/hlo",
        "//tensorflow/compiler/mlir/tensorflow:convert_tensor",
        "//tensorflow/compiler/mlir/tensorflow:convert_type",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/core:framework",
        "//tensorflow/core/platform:logging",
        "//tensorflow/core/platform:types",
        # LLVM / MLIR.
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Support",
    ],
)
280
# Unit test for :type_to_shape, using the TensorFlow test_main entry point.
tf_cc_test(
    name = "type_to_shape_test",
    srcs = ["type_to_shape_test.cc"],
    deps = [
        # Code under test.
        ":hlo_utils",
        ":type_to_shape",
        # Test support and other dependencies.
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test_main",
        "@llvm-project//mlir:IR",
    ],
)
296
# Exporter library in mlir_hlo_to_hlo.{h,cc}. It compiles in
# operator_writers.inc, which :operator_writer_inc generates from hlo_ops.td.
cc_library(
    name = "mlir_hlo_to_hlo",
    hdrs = ["mlir_hlo_to_hlo.h"],
    srcs = [
        "mlir_hlo_to_hlo.cc",
        "operator_writers.inc",
    ],
    deps = [
        # Targets in this package.
        ":attribute_exporter",
        ":type_to_shape",
        ":xla_passes",
        # TensorFlow / XLA.
        "//tensorflow/compiler/mlir:name_utils",
        "//tensorflow/compiler/mlir/hlo",
        "//tensorflow/compiler/mlir/tensorflow:convert_type",
        "//tensorflow/compiler/mlir/tensorflow:error_util",
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:xla_helpers",
        "//tensorflow/compiler/xla:comparison_util",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/client/lib:matrix",
        "//tensorflow/compiler/xla/client/lib:quantize",
        "//tensorflow/compiler/xla/client/lib:slicing",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/stream_executor/lib",
        # LLVM / MLIR.
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:TensorDialect",
        "@llvm-project//mlir:TransformUtils",
        "@llvm-project//mlir:Transforms",
    ],
)
338
# Import entry points (hlo_to_mlir_hlo.{h,cc}) layered on :hlo_module_importer.
cc_library(
    name = "hlo_to_mlir_hlo",
    hdrs = ["hlo_to_mlir_hlo.h"],
    srcs = ["hlo_to_mlir_hlo.cc"],
    deps = [
        ":hlo_module_importer",
        "//tensorflow/compiler/mlir/tensorflow:error_util",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/core:lib",
    ],
)
351
# Module- and function-level importers (hlo_module_importer.* and
# hlo_function_importer.*), packaged together in one library.
cc_library(
    name = "hlo_module_importer",
    hdrs = [
        "hlo_function_importer.h",
        "hlo_module_importer.h",
    ],
    srcs = [
        "hlo_function_importer.cc",
        "hlo_module_importer.cc",
    ],
    deps = [
        # Targets in this package.
        ":attribute_importer",
        ":hlo_utils",
        # TensorFlow / XLA.
        "//tensorflow/compiler/mlir/hlo",
        "//tensorflow/compiler/mlir/tensorflow:error_util",
        "//tensorflow/compiler/xla:comparison_util",
        "//tensorflow/compiler/xla:protobuf_util",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla:xla_proto_cc",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_casting_utils",
        "//tensorflow/core:lib",
        # Abseil.
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/types:optional",
        # LLVM / MLIR.
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:StandardOps",
    ],
)
384
# Attribute import helpers (attribute_importer.{h,cc}).
cc_library(
    name = "attribute_importer",
    hdrs = ["attribute_importer.h"],
    srcs = ["attribute_importer.cc"],
    deps = [
        "//tensorflow/compiler/mlir/hlo",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/core/platform:types",
        "@llvm-project//mlir:IR",
    ],
)
398
# Attribute export helpers (attribute_exporter.{h,cc}). They depend on the
# LHLO GPU dialect and stream_executor's dnn target.
cc_library(
    name = "attribute_exporter",
    hdrs = ["attribute_exporter.h"],
    srcs = ["attribute_exporter.cc"],
    deps = [
        "//tensorflow/compiler/mlir/hlo",
        "//tensorflow/compiler/mlir/hlo:lhlo_gpu",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/core/platform:types",
        "//tensorflow/stream_executor:dnn",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
    ],
)
416
# Command-line options for the XLA/MLIR translation tools
# (xla_mlir_translate_cl.{h,cc}); it depends only on LLVM Support.
cc_library(
    name = "translate_cl_options",
    hdrs = ["xla_mlir_translate_cl.h"],
    srcs = ["xla_mlir_translate_cl.cc"],
    # NOTE(review): alwayslink presumably keeps statically registered
    # options linked even when nothing references them directly.
    alwayslink = 1,
    deps = ["@llvm-project//llvm:Support"],
)
426
# Translation library (xla_mlir_translate.{h,cc}) combining the importer,
# the exporters and the MHLO->LHLO conversion. It also links both the CPU
# and GPU XLA JIT targets.
cc_library(
    name = "xla_mlir_translate",
    hdrs = ["xla_mlir_translate.h"],
    srcs = ["xla_mlir_translate.cc"],
    alwayslink = 1,
    deps = [
        # Targets in this package.
        ":hlo_to_mlir_hlo",
        ":mhlo_to_lhlo_with_xla",
        ":mlir_hlo_to_hlo",
        ":translate_cl_options",
        ":type_to_shape",
        # TensorFlow / XLA.
        "//tensorflow/compiler/jit:xla_cpu_jit",
        "//tensorflow/compiler/jit:xla_gpu_jit",
        "//tensorflow/compiler/mlir/hlo",
        "//tensorflow/compiler/xla:debug_options_flags",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla/service:hlo_parser",
        "//tensorflow/compiler/xla/service:hlo_proto_cc",
        "//tensorflow/core:lib",
        # LLVM / MLIR.
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:TensorDialect",
        "@llvm-project//mlir:Translation",
    ],
)
454
# Custom TableGen backend. :operator_writer_inc runs it as its `tblgen` tool
# to emit operator_writers.inc.
tf_native_cc_binary(
    name = "operator_writer_gen",
    srcs = ["operator_writer_gen.cc"],
    deps = [
        # LLVM.
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:TableGen",
        # MLIR.
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TableGen",
    ],
)
465
# Runs :operator_writer_gen (with no extra flags) over the MHLO op definitions
# in hlo_ops.td to produce operator_writers.inc. That file is compiled into
# :mlir_hlo_to_hlo.
gentbl(
    name = "operator_writer_inc",
    td_file = "//tensorflow/compiler/mlir/hlo:include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td",
    td_srcs = [
        "@llvm-project//mlir:include/mlir/IR/OpBase.td",
        "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
        "@llvm-project//mlir:SideEffectTdFiles",
        "//tensorflow/compiler/mlir/hlo:hlo_ops_td_files",
        # Any file in this directory works here: listing one forces the
        # current path to exist so the relative include path can be resolved.
        "BUILD",
    ],
    td_relative_includes = ["../hlo/include"],
    tblgen = ":operator_writer_gen",
    tbl_outs = [("", "operator_writers.inc")],
    compatible_with = get_compatible_with_cloud(),
)
485
# Aggregate, source-less library that links every pass target needed by the
# xla-opt test driver. Visibility is restricted to
# //tensorflow/compiler/mlir subpackages.
cc_library(
    name = "all_xla_passes_for_testing",
    deps = [
        # Pass libraries from this package.
        ":mhlo_to_lhlo_with_xla",
        ":xla_legalize_tf",
        ":xla_legalize_tf_with_tf2xla",
        ":xla_passes",
        # MLIR-HLO dialects and pass libraries.
        "//tensorflow/compiler/mlir/hlo",
        "//tensorflow/compiler/mlir/hlo:chlo_legalize_to_hlo",
        "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo",
        "//tensorflow/compiler/mlir/hlo:legalize_control_flow",
        "//tensorflow/compiler/mlir/hlo:legalize_to_linalg",
        "//tensorflow/compiler/mlir/hlo:legalize_to_standard",
        "//tensorflow/compiler/mlir/hlo:legalize_trigonometric_to_approximation",
        "//tensorflow/compiler/mlir/hlo:lhlo",
        "//tensorflow/compiler/mlir/hlo:lhlo_fuse_linalg",
        "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_affine",
        "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_gpu",
        "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_parallel_loops",
        "//tensorflow/compiler/mlir/hlo:mhlo_fusion",
        "//tensorflow/compiler/mlir/hlo:mhlo_to_mhlo_lowering_patterns",
        "//tensorflow/compiler/mlir/hlo:sink_constants_to_control_flow",
        "//tensorflow/compiler/mlir/hlo:test_passes",
        "//tensorflow/compiler/mlir/hlo:transform_unranked_hlo",
    ],
    visibility = ["//tensorflow/compiler/mlir:__subpackages__"],
)
515
# Test driver binary: the shared tf_mlir_opt_main entry point, every pass from
# :all_xla_passes_for_testing, and the CPU and GPU XLA service plugins.
tf_cc_binary(
    name = "xla-opt",
    deps = [
        ":all_xla_passes_for_testing",
        "//tensorflow/compiler/mlir:tf_mlir_opt_main",
        # XLA backends.
        "//tensorflow/compiler/xla/service:cpu_plugin",
        "//tensorflow/compiler/xla/service:gpu_plugin",
    ],
)
525