• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# TensorFlow -> TOSA Compiler Bridge.
# See:
#   https://developer.mlplatform.org/w/tosa/
#   https://github.com/llvm/llvm-project/blob/main/mlir/docs/Dialects/TOSA.md
5
6load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud")
7load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
8
# By default every target here is visible only to the `internal` group;
# individual targets widen this explicitly (e.g. to `friends`).
# TODO: Tighten visibility once targets are at the right granularity.
package(
    default_visibility = [":internal"],
    licenses = ["notice"],
)
14
# Every package under //tensorflow/compiler/mlir.
package_group(
    name = "internal",
    packages = ["//tensorflow/compiler/mlir/..."],
)
21
# Wider audience for the public pass libraries: the `internal` group plus
# everything under //third_party/iree.
package_group(
    name = "friends",
    includes = [":internal"],
    packages = ["//third_party/iree/..."],
)
31
# Local alias for the upstream TOSA dialect TableGen sources, consumed by the
# pattern generators below.
filegroup(
    name = "tosa_ops_td_files",
    srcs = ["@llvm-project//mlir:TosaDialectTdFiles"],
    compatible_with = get_compatible_with_cloud(),
)
39
# Runs mlir-tblgen with -gen-pass-decls (and -name=LegalizeTosa) over
# transforms/passes.td, producing transforms/passes.h.inc.
gentbl_cc_library(
    name = "tosa_passes_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [(
        [
            "-gen-pass-decls",
            "-name=LegalizeTosa",
        ],
        "transforms/passes.h.inc",
    )],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "transforms/passes.td",
    deps = ["@llvm-project//mlir:PassBaseTdFiles"],
)
58
# Header-only target: transforms/passes.h plus the generated
# transforms/passes.h.inc (an output of :tosa_passes_inc_gen).
cc_library(
    name = "passes_header",
    hdrs = [
        "transforms/passes.h",
        "transforms/passes.h.inc",
    ],
    compatible_with = get_compatible_with_cloud(),
    deps = ["@llvm-project//mlir:Pass"],
)
68
# Shared legalization code and utilities; a dependency of both :tf_passes
# and :tfl_passes.
# NOTE(review): alwayslink may not be required here unless these sources rely
# on static initializers — confirm before changing.
cc_library(
    name = "legalize_common",
    srcs = [
        "transforms/legalize_common.cc",
        "transforms/legalize_utils.cc",
    ],
    hdrs = [
        "transforms/legalize_common.h",
        "transforms/legalize_utils.h",
    ],
    compatible_with = get_compatible_with_cloud(),
    deps = [
        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
        "//tensorflow/core:framework",
        "//tensorflow/core/kernels:conv_grad_shape_utils",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TensorDialect",
        "@llvm-project//mlir:TosaDialect",
    ],
    alwayslink = True,
)
93
# Generates rewrite patterns (-gen-rewriters) from
# transforms/tf_legalize_patterns.td; the resulting .inc file is listed
# directly in the srcs of :tf_passes.
gentbl_cc_library(
    name = "tosa_legalize_tf_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [(["-gen-rewriters"], "transforms/tf_legalize_patterns.inc")],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "transforms/tf_legalize_patterns.td",
    deps = [
        ":tosa_ops_td_files",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
        "@llvm-project//mlir:StdOpsTdFiles",
    ],
)
111
# Pass library built from tf_passes.cc and the TF legalization sources;
# exported to the `friends` group.
# NOTE(review): transforms/passes.h is also exported by :passes_header (a
# dep of this target). The duplicate hdrs entry is kept so clients that
# include it via this target keep building under layering checks.
cc_library(
    name = "tf_passes",
    srcs = [
        "tf_passes.cc",
        "transforms/fuse_bias_tf.cc",
        "transforms/legalize_tf.cc",
        # Generated by :tosa_legalize_tf_inc_gen.
        "transforms/tf_legalize_patterns.inc",
    ],
    hdrs = [
        "tf_passes.h",
        "transforms/passes.h",
    ],
    compatible_with = get_compatible_with_cloud(),
    visibility = [":friends"],
    deps = [
        ":legalize_common",
        ":passes_header",
        "//tensorflow/compiler/mlir/tensorflow",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AffineTransforms",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TosaDialect",
        "@llvm-project//mlir:Transforms",
    ],
    alwayslink = True,
)
142
# Generates rewrite patterns (-gen-rewriters) from
# transforms/tfl_legalize_patterns.td; the resulting .inc file is listed
# directly in the srcs of :tfl_passes.
gentbl_cc_library(
    name = "tosa_legalize_tfl_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [(["-gen-rewriters"], "transforms/tfl_legalize_patterns.inc")],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "transforms/tfl_legalize_patterns.td",
    deps = [
        ":tosa_ops_td_files",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
        "@llvm-project//mlir:StdOpsTdFiles",
    ],
)
160
# Pass library built from tfl_passes.cc and the TFLite legalization sources;
# exported to the `friends` group.
# NOTE(review): transforms/passes.h is also exported by :passes_header (a
# dep of this target). The duplicate hdrs entry is kept so clients that
# include it via this target keep building under layering checks.
cc_library(
    name = "tfl_passes",
    srcs = [
        "tfl_passes.cc",
        "transforms/convert_tfl_uint8.cc",
        "transforms/legalize_tfl.cc",
        # Generated by :tosa_legalize_tfl_inc_gen.
        "transforms/tfl_legalize_patterns.inc",
    ],
    hdrs = [
        "tfl_passes.h",
        "transforms/passes.h",
    ],
    compatible_with = get_compatible_with_cloud(),
    visibility = [":friends"],
    deps = [
        ":legalize_common",
        ":passes_header",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AffineTransforms",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TosaDialect",
        "@llvm-project//mlir:Transforms",
    ],
    alwayslink = True,
)
192