# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrapper for post training quantization with calibration."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.util.lazy_loader import LazyLoader

# Lazy load since some of the performance benchmark skylark rules
# break dependencies. Must use double quotes to match code internal rewrite
# rule.
_calibration_wrapper = LazyLoader(
    "_calibration_wrapper", globals(),
    "tensorflow.lite.python.optimize."
    "tensorflow_lite_wrap_calibration_wrapper")


class Calibrator(object):
  """Calibrates a floating point model and then quantizes it.

  This is an internal class, not a public interface.
  """

  def __init__(self, model_content):
    """Constructor.

    Args:
      model_content: Content of a TF-Lite Flatbuffer file.

    Raises:
      ValueError: If `model_content` is empty, or if the calibrator was
        unable to open the model.
    """
    if not model_content:
      raise ValueError("`model_content` must be specified.")
    try:
      self._calibrator = (_calibration_wrapper.CalibrationWrapper
                          .CreateWrapperCPPFromBuffer(model_content))
    except Exception as e:
      # Any failure raised by the native wrapper is surfaced as ValueError.
      # `raise ... from e` is intentionally not used: the `__future__`
      # imports above indicate this module still supports Python 2, where
      # that syntax is invalid.
      raise ValueError("Failed to parse the model: %s." % e)
    # Also guard against the wrapper returning a falsy value (e.g. None)
    # instead of raising.
    if not self._calibrator:
      raise ValueError("Failed to parse the model.")

  def calibrate_and_quantize(self, dataset_gen):
    """Calibrates the model with specified generator and then quantizes it.

    Args:
      dataset_gen: A zero-argument callable (for example a generator
        function) that returns an iterable of calibration samples. Note that
        it is *called* here, so passing an already-created generator object
        will fail. Each sample it produces is passed unchanged to the
        wrapper's `FeedTensor`.

    Returns:
      A quantized model, as returned by the wrapper's `QuantizeModel()`.
    """
    # Prepare the wrapper for calibration; this runs on every call.
    self._calibrator.Prepare()
    for calibration_sample in dataset_gen():
      self._calibrator.FeedTensor(calibration_sample)
    return self._calibrator.QuantizeModel()