# Bazel build rules for the TensorFlow Lite Python calibration wrapper
# (C++ library, pybind11 extension, Python calibrator module and its test).
load("//tensorflow:tensorflow.bzl", "pybind_extension")

# Package-wide defaults: every target below is publicly visible unless it
# sets its own `visibility`, and the package is declared under the "notice"
# license category.
package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],
)
7
# C++ implementation of the calibration wrapper (calibration_wrapper.cc/.h).
# It is bound to Python by the `_pywrap_tensorflow_lite_calibration_wrapper`
# pybind11 extension, which depends on this library.
# Dependency order is kept as-is; do not re-sort without checking link order.
cc_library(
    name = "calibration_wrapper_lib",
    srcs = ["calibration_wrapper.cc"],
    hdrs = ["calibration_wrapper.h"],
    deps = [
        "//tensorflow/lite:framework",
        "//tensorflow/lite:shared_library",
        "//tensorflow/lite/c:common",
        "//tensorflow/lite/kernels:builtin_ops",
        "//tensorflow/lite/python/interpreter_wrapper:numpy",
        "//tensorflow/lite/python/interpreter_wrapper:python_error_reporter",
        "//tensorflow/lite/python/interpreter_wrapper:python_utils",
        "//tensorflow/lite/tools/optimize:quantization_wrapper_utils",
        "//tensorflow/lite/tools/optimize:quantize_model",
        "//tensorflow/lite/tools/optimize/calibration:calibration_reader",
        "//tensorflow/lite/tools/optimize/calibration:calibrator_lib",
        "//third_party/python_runtime:headers",  # buildcleaner: keep
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:optional",
    ],
)
30
# pybind11 Python extension module built from calibration_wrapper_pybind11.cc,
# exposing the C++ code in `:calibration_wrapper_lib` to Python.
# `pybind_extension` is the macro loaded from //tensorflow:tensorflow.bzl.
pybind_extension(
    name = "_pywrap_tensorflow_lite_calibration_wrapper",
    srcs = [
        "calibration_wrapper_pybind11.cc",
    ],
    # NOTE(review): calibration_wrapper.h is also exported by
    # :calibration_wrapper_lib; listing it here may be redundant. Confirm
    # against the pybind_extension macro before removing.
    hdrs = ["calibration_wrapper.h"],
    # Semantics are defined by pybind_extension in tensorflow.bzl. This
    # presumably links the TensorFlow framework into the module; confirm there.
    link_in_framework = True,
    deps = [
        ":calibration_wrapper_lib",
        "//tensorflow/lite:framework_lib",
        "//tensorflow/python:pybind11_lib",
        "//third_party/python_runtime:headers",
        "@pybind11",
    ],
)
46
# Python calibrator module (calibrator.py), built on top of the
# `_pywrap_tensorflow_lite_calibration_wrapper` extension.
py_library(
    name = "calibrator",
    srcs = [
        "calibrator.py",
    ],
    srcs_version = "PY3",
    # Same as the package default_visibility; stated explicitly here.
    visibility = ["//visibility:public"],
    deps = [
        # Kept for build-cleanup tooling via the marker below. Presumably the
        # tool cannot see the extension's use from calibrator.py; confirm.
        ":_pywrap_tensorflow_lite_calibration_wrapper",  # buildcleaner: keep
        "//tensorflow/lite/python:convert_phase",
        "//tensorflow/lite/python:interpreter",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:util",
        "//third_party/py/numpy",
    ],
)
63
# Python unit test for the `:calibrator` library (calibrator_test.py).
py_test(
    name = "calibrator_test",
    srcs = ["calibrator_test.py"],
    data = [
        # NOTE(review): no `test_data` target is defined in this BUILD file as
        # shown. Verify it exists (e.g. a filegroup) or the test will not build.
        ":test_data",
        "//tensorflow/lite:testdata/multi_add.bin",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
    # TensorFlow convention: presumably excludes this test from open-source
    # (OSS) builds. Confirm against the CI configuration.
    tags = ["no_oss"],
    deps = [
        ":calibrator",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
84