• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
2load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud")
3load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library")
4load("//tensorflow/compiler/mlir/quantization/tensorflow:internal_visibility_allowlist.bzl", "internal_visibility_allowlist")
5load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
6
# Package group used as the default visibility of the targets in this package.
# The fixed package list below is extended with whatever
# internal_visibility_allowlist() returns.
package_group(
    name = "internal_visibility_allowlist_package",
    packages = [
        "//tensorflow/compiler/mlir/quantization/...",
        "//tensorflow/compiler/mlir/lite/...",
        # TPU Inference Converter V1
        "//third_party/cloud_tpu/inference_converter/...",
    ] + internal_visibility_allowlist(),
)
15
# Targets in this package are visible only to the allowlisted package group
# declared in this file, and carry the "notice" license.
package(
    default_visibility = [":internal_visibility_allowlist_package"],
    licenses = ["notice"],
)
22
# Python tool run by the `quantized_function_library` genrule to produce a
# C++ header from the quantized function library MLIR files.
py_binary(
    name = "gen_quantized_function_library",
    srcs = [
        "gen_quantized_function_library.py",
    ],
    deps = [
        "@absl_py//absl:app",
        "@absl_py//absl/flags",
    ],
)
31
# Generates passes/quantized_function_library.h from the quantized function
# library MLIR files by running the gen_quantized_function_library tool.
# All sources are passed as ONE quoted, space-separated argument to --src;
# presumably the tool splits that string itself (confirm in the .py script
# before changing the quoting).
genrule(
    name = "quantized_function_library",
    srcs = [
        "passes/quantized_function_library.mlir",
        "passes/quantized_function_library_uniform_quantized_drq.mlir",
        "passes/quantized_function_library_tf_drq.mlir",
    ],
    outs = [
        "passes/quantized_function_library.h",
    ],
    # `$@` expands to the single declared output, so the output path is not
    # duplicated here and cannot drift out of sync with `outs`.
    cmd = "$(location gen_quantized_function_library) --output_file $@ --src '$(SRCS)'",
    compatible_with = get_compatible_with_cloud(),
    exec_tools = ["gen_quantized_function_library"],
)
46
# Utility helpers (passes/utils.{h,cc}) that the `passes` library depends on.
cc_library(
    name = "pass_utils",
    srcs = [
        "passes/utils.cc",
    ],
    hdrs = [
        "passes/utils.h",
    ],
    compatible_with = get_compatible_with_cloud(),
    deps = [
        # Same-package target, referenced with a relative label just as the
        # `passes` library does.
        ":quantization_options_proto_cc",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:eval_util",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
    ],
)
65
# TableGen sources for this package. Every *_inc_gen target below lists this
# library in its deps. The external TableGen libraries these sources depend on
# are forwarded through `deps`.
td_library(
    name = "quant_td_files",
    srcs = [
        "passes/lift_quantizable_spots_as_functions.td",
        "passes/lift_quantizable_spots_as_functions_drq.td",
        "passes/optimize.td",
        "passes/prepare_lifting.td",
        "passes/prepare_quantize.td",
        "passes/quantize_composite_functions.td",
        "passes/replace_cast_hacks_with_tf_xla_ops.td",
        "passes/tf_quant_ops.td",
        "passes/utils.td",
    ],
    compatible_with = get_compatible_with_cloud(),
    deps = [
        "//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
        "//tensorflow/compiler/mlir/quantization/tensorflow/utils:lift_as_function_call_utils_td_files",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
        "@llvm-project//mlir:ArithmeticOpsTdFiles",
        "@llvm-project//mlir:FuncTdFiles",
    ],
)
88
# Runs mlir-tblgen -gen-rewriters over passes/prepare_lifting.td, producing
# passes/prepare_lifting.inc (listed in the srcs of the `passes` library).
gentbl_cc_library(
    name = "prepare_lifting_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (["-gen-rewriters"], "passes/prepare_lifting.inc"),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "passes/prepare_lifting.td",
    deps = [
        ":quant_td_files",
    ],
)
102
# Runs mlir-tblgen -gen-rewriters over
# passes/lift_quantizable_spots_as_functions.td, producing the matching .inc
# file (listed in the srcs of the `passes` library).
gentbl_cc_library(
    name = "lift_quantizable_spots_as_functions_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (["-gen-rewriters"], "passes/lift_quantizable_spots_as_functions.inc"),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "passes/lift_quantizable_spots_as_functions.td",
    deps = [
        ":quant_td_files",
    ],
)
116
# Runs mlir-tblgen -gen-rewriters over
# passes/lift_quantizable_spots_as_functions_drq.td, producing the matching
# .inc file (listed in the srcs of the `passes` library).
gentbl_cc_library(
    name = "lift_quantizable_spots_as_functions_drq_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (["-gen-rewriters"], "passes/lift_quantizable_spots_as_functions_drq.inc"),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "passes/lift_quantizable_spots_as_functions_drq.td",
    deps = [
        ":quant_td_files",
    ],
)
130
# Runs mlir-tblgen -gen-rewriters over passes/prepare_quantize.td, producing
# passes/prepare_quantize.inc (listed in the srcs of the `passes` library).
gentbl_cc_library(
    name = "prepare_quantize_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (["-gen-rewriters"], "passes/prepare_quantize.inc"),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "passes/prepare_quantize.td",
    deps = [
        ":quant_td_files",
    ],
)
144
# Runs mlir-tblgen -gen-rewriters over passes/quantize_composite_functions.td,
# producing the matching .inc file (listed in the srcs of the `passes`
# library).
gentbl_cc_library(
    name = "quantize_composite_functions_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (["-gen-rewriters"], "passes/quantize_composite_functions.inc"),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "passes/quantize_composite_functions.td",
    deps = [
        ":quant_td_files",
    ],
)
158
# Runs mlir-tblgen over passes/tf_quant_ops.td twice: -gen-op-decls emits the
# op declarations (.h.inc) and -gen-op-defs emits the op definitions (.cc.inc).
# Both outputs are listed in the srcs of the `tf_quant_ops` library.
gentbl_cc_library(
    name = "tf_quant_ops_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (["-gen-op-decls"], "passes/tf_quant_ops.h.inc"),
        (["-gen-op-defs"], "passes/tf_quant_ops.cc.inc"),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "passes/tf_quant_ops.td",
    deps = [":quant_td_files"],
)
178
# Runs mlir-tblgen -gen-rewriters over passes/optimize.td, producing
# passes/optimize.inc (listed in the srcs of the `passes` library).
gentbl_cc_library(
    name = "optimize_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (["-gen-rewriters"], "passes/optimize.inc"),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "passes/optimize.td",
    deps = [
        ":quant_td_files",
    ],
)
192
# C++ library for the ops generated from passes/tf_quant_ops.td. The
# tblgen-produced .h.inc/.cc.inc files (from tf_quant_ops_inc_gen) are listed
# directly in srcs next to the hand-written passes/tf_quant_ops.cc.
cc_library(
    name = "tf_quant_ops",
    srcs = [
        "passes/tf_quant_ops.cc",
        "passes/tf_quant_ops.cc.inc",
        "passes/tf_quant_ops.h.inc",
    ],
    hdrs = [
        "passes/tf_quant_ops.h",
    ],
    compatible_with = get_compatible_with_cloud(),
    deps = [
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_op_interfaces",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_side_effects",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_structs",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_traits",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ControlFlowInterfaces",
        "@llvm-project//mlir:DerivedAttributeOpInterface",
        "@llvm-project//mlir:Dialect",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:InferTypeOpInterface",
        "@llvm-project//mlir:LoopLikeInterface",
        "@llvm-project//mlir:Parser",
        "@llvm-project//mlir:SideEffectInterfaces",
        "@llvm-project//mlir:Support",
    ],
)
222
# Library built from ops/tf_op_quant_spec.{h,cc}; the `passes` library
# depends on it.
cc_library(
    name = "tf_op_quant_spec",
    srcs = ["ops/tf_op_quant_spec.cc"],
    hdrs = [
        "ops/tf_op_quant_spec.h",
    ],
    compatible_with = get_compatible_with_cloud(),
    deps = [
        "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/tensorflow",
        "@com_google_absl//absl/container:flat_hash_set",
        "@llvm-project//mlir:IR",
    ],
)
238
# Runs mlir-tblgen -gen-rewriters over
# passes/replace_cast_hacks_with_tf_xla_ops.td, producing the matching .inc
# file (listed in the srcs of the `passes` library).
gentbl_cc_library(
    name = "replace_cast_hacks_with_tf_xla_ops_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (["-gen-rewriters"], "passes/replace_cast_hacks_with_tf_xla_ops.inc"),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "passes/replace_cast_hacks_with_tf_xla_ops.td",
    deps = [
        ":quant_td_files",
    ],
)
252
# The quantization pass library exposed through passes/passes.h.
# Generated inputs are listed directly in srcs:
#   - the tblgen-produced *.inc rewrite files from the *_inc_gen targets, and
#   - passes/quantized_function_library.h from the
#     `quantized_function_library` genrule.
# alwayslink = 1 keeps every object file of this library in the final link
# even when nothing references its symbols directly (see the Bazel
# cc_library docs for `alwayslink`).
cc_library(
    name = "passes",
    srcs = [
        "passes/convert_custom_aggregation_op_to_quant_stats.cc",
        "passes/convert_fake_quant_to_qdq.cc",
        "passes/convert_tf_quant_ops_to_mhlo.cc",
        "passes/insert_custom_aggregation_ops.cc",
        "passes/insert_main_function.cc",
        "passes/insert_quantized_functions.cc",
        "passes/issue_ids_of_custom_aggregation_ops.cc",
        "passes/lift_quantizable_spots_as_functions.cc",
        "passes/lift_quantizable_spots_as_functions.inc",
        "passes/lift_quantizable_spots_as_functions_drq.cc",
        "passes/lift_quantizable_spots_as_functions_drq.inc",
        "passes/optimize.cc",
        "passes/optimize.inc",
        "passes/post_quantize.cc",
        "passes/prepare_lifting.cc",
        "passes/prepare_lifting.inc",
        "passes/prepare_quantize.cc",
        "passes/prepare_quantize.inc",
        "passes/prepare_quantize_drq.cc",
        "passes/quantize.cc",
        "passes/quantize_composite_functions.cc",
        "passes/quantize_composite_functions.inc",
        "passes/quantized_function_library.h",
        "passes/replace_cast_hacks_with_tf_xla_ops.cc",
        "passes/replace_cast_hacks_with_tf_xla_ops.inc",
    ],
    hdrs = [
        "passes/passes.h",
    ],
    compatible_with = get_compatible_with_cloud(),
    deps = [
        # Same-package targets, all referenced with relative labels.
        ":pass_utils",
        ":quantization_options_proto_cc",
        ":tf_op_quant_spec",
        ":tf_quant_ops",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps",
        "//tensorflow/compiler/mlir/quantization/tensorflow/utils:fake_quant_utils",
        "//tensorflow/compiler/mlir/quantization/tensorflow/utils:lift_as_function_call_utils",
        "//tensorflow/compiler/mlir/quantization/tensorflow/utils:tf_to_xla_attribute_utils",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:error_util",
        "//tensorflow/compiler/mlir/tensorflow:mangling_util",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/mlir_hlo",
        "//tensorflow/core:framework",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/ir/importexport:convert_tensor",
        "//tensorflow/core/ir/importexport:mangling",
        "//tensorflow/core/platform:env",
        "//tensorflow/core/platform:macros",
        "//tensorflow/core/platform:path",
        "//tensorflow/lite/kernels:padding",
        "//tensorflow/lite/kernels/internal:quantization_util",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/random",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:ArithmeticDialect",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Parser",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TransformUtils",
        "@llvm-project//mlir:Transforms",
    ],
    alwayslink = 1,
)
331
# Proto definitions from quantization_options.proto. The C++ target
# `:quantization_options_proto_cc` referenced elsewhere in this file is
# presumably generated by this tf_proto_library macro; verify in
# build_config.bzl.
tf_proto_library(
    name = "quantization_options_proto",
    srcs = [
        "quantization_options.proto",
    ],
    cc_api_version = 2,
)
337
# copybara:uncomment_begin(google-only)
# py_proto_library(
#     name = "quantization_options_py_pb2",
#     api_version = 2,
#     deps = [":quantization_options_proto"],
# )
# copybara:uncomment_end
345
# `tf-quant-opt` command-line driver built from passes/tf_quant_opt.cc. It
# links the `passes` library together with the TensorFlow and upstream MLIR
# dialect/pass libraries listed in deps.
tf_cc_binary(
    name = "tf-quant-opt",
    srcs = [
        "passes/tf_quant_opt.cc",
    ],
    deps = [
        ":passes",
        "//tensorflow/compiler/mlir:init_mlir",
        "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
        "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:MlirOptLib",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:ShapeDialect",
    ],
)
364