# Build targets for the TFLite Python calibration wrapper: a C++ library, its
# pybind11 extension module, the pure-Python `calibrator` library on top of it,
# and that library's unit test.

load("//tensorflow:tensorflow.bzl", "pybind_extension")

package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],
)

# C++ calibration wrapper. Depends on the TFLite framework and builtin ops, the
# optimize/calibration tooling (calibrator_lib, calibration_reader,
# quantize_model) and the Python interpreter_wrapper helpers (numpy,
# python_error_reporter, python_utils).
cc_library(
    name = "calibration_wrapper_lib",
    srcs = ["calibration_wrapper.cc"],
    hdrs = ["calibration_wrapper.h"],
    deps = [
        "//tensorflow/lite:framework",
        "//tensorflow/lite:shared_library",
        "//tensorflow/lite/c:common",
        "//tensorflow/lite/kernels:builtin_ops",
        "//tensorflow/lite/python/interpreter_wrapper:numpy",
        "//tensorflow/lite/python/interpreter_wrapper:python_error_reporter",
        "//tensorflow/lite/python/interpreter_wrapper:python_utils",
        "//tensorflow/lite/tools/optimize:quantization_wrapper_utils",
        "//tensorflow/lite/tools/optimize:quantize_model",
        "//tensorflow/lite/tools/optimize/calibration:calibration_reader",
        "//tensorflow/lite/tools/optimize/calibration:calibrator_lib",
        "//third_party/python_runtime:headers",  # buildcleaner: keep
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:optional",
    ],
)

# pybind11 extension module exposing :calibration_wrapper_lib to Python.
# Deps order is kept as-is; cc link order may depend on it.
pybind_extension(
    name = "_pywrap_tensorflow_lite_calibration_wrapper",
    srcs = ["calibration_wrapper_pybind11.cc"],
    hdrs = ["calibration_wrapper.h"],
    link_in_framework = True,
    deps = [
        ":calibration_wrapper_lib",
        "//tensorflow/lite:framework_lib",
        "//tensorflow/python:pybind11_lib",
        "//third_party/python_runtime:headers",
        "@pybind11",
    ],
)

# Python-side calibrator (calibrator.py), backed by the extension module above.
# Visibility is stated explicitly even though the package default is also
# public.
py_library(
    name = "calibrator",
    srcs = ["calibrator.py"],
    srcs_version = "PY3",
    visibility = ["//visibility:public"],
    deps = [
        ":_pywrap_tensorflow_lite_calibration_wrapper",  # buildcleaner: keep
        "//tensorflow/lite/python:convert_phase",
        "//tensorflow/lite/python:interpreter",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:util",
        "//third_party/py/numpy",
    ],
)

# Unit test for :calibrator. Runtime data is the package-local :test_data label
# plus the multi_add.bin model from //tensorflow/lite. Tagged no_oss.
py_test(
    name = "calibrator_test",
    srcs = ["calibrator_test.py"],
    data = [
        ":test_data",
        "//tensorflow/lite:testdata/multi_add.bin",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
    tags = ["no_oss"],
    deps = [
        ":calibrator",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)