# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for modify_model_interface_lib.py."""

import os
import numpy as np
import tensorflow as tf

from tensorflow.lite.tools.optimize.python import modify_model_interface_lib
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


def build_tflite_model_with_full_integer_quantization(
    supported_ops=tf.lite.OpsSet.TFLITE_BUILTINS_INT8):
  """Builds a tiny Keras model and converts it to a quantized TFLite model.

  The Keras model takes a float32 vector of length 3 and applies two Dense
  layers (5 units with ReLU, then 2 units with softmax). Conversion uses
  `tf.lite.Optimize.DEFAULT` together with a ten-sample representative
  dataset.

  Args:
    supported_ops: The single `tf.lite.OpsSet` value placed in the converter's
      `target_spec.supported_ops` list.

  Returns:
    The serialized model returned by `TFLiteConverter.convert()`.
  """
  # Define TF model
  input_size = 3
  model = tf.keras.Sequential([
      tf.keras.layers.InputLayer(input_shape=(input_size,), dtype=tf.float32),
      tf.keras.layers.Dense(units=5, activation=tf.nn.relu),
      tf.keras.layers.Dense(units=2, activation=tf.nn.softmax)
  ])

  # Convert TF Model to a Quantized TFLite Model
  converter = tf.lite.TFLiteConverter.from_keras_model(model)
  converter.optimizations = [tf.lite.Optimize.DEFAULT]

  def representative_dataset_gen():
    # Ten calibration samples: [0, 0, 0], [1, 1, 1], ..., [9, 9, 9].
    for i in range(10):
      yield [np.array([i] * input_size, dtype=np.float32)]

  converter.representative_dataset = representative_dataset_gen
  converter.target_spec.supported_ops = [supported_ops]
  tflite_model = converter.convert()

  return tflite_model


def _get_interface_dtypes(model_path):
  """Returns the (input, output) dtypes of the first input/output tensors."""
  interpreter = tf.lite.Interpreter(model_path=model_path)
  interpreter.allocate_tensors()
  input_dtype = interpreter.get_input_details()[0]['dtype']
  output_dtype = interpreter.get_output_details()[0]['dtype']
  return input_dtype, output_dtype


class ModifyModelInterfaceTest(test_util.TensorFlowTestCase):
  """Checks that modify_model_interface rewrites the model's I/O dtypes."""

  def _assert_interface_modified(self, tf_dtype, np_dtype, supported_ops=None):
    """Runs modify_model_interface and validates the resulting interface.

    Args:
      tf_dtype: TensorFlow dtype passed to `modify_model_interface` as both
        the requested input type and the requested output type.
      np_dtype: NumPy dtype the modified model's first input and first output
        are expected to report.
      supported_ops: Optional `tf.lite.OpsSet` forwarded to the model builder.
        When None, the builder's own default is used.
    """
    # 1. SETUP
    # Define the temporary directory and files
    temp_dir = self.get_temp_dir()
    initial_file = os.path.join(temp_dir, 'initial_model.tflite')
    final_file = os.path.join(temp_dir, 'final_model.tflite')
    # Define initial model
    if supported_ops is None:
      initial_model = build_tflite_model_with_full_integer_quantization()
    else:
      initial_model = build_tflite_model_with_full_integer_quantization(
          supported_ops=supported_ops)
    with open(initial_file, 'wb') as model_file:
      model_file.write(initial_model)

    # 2. INVOKE
    # Invoke the modify_model_interface function
    modify_model_interface_lib.modify_model_interface(initial_file, final_file,
                                                      tf_dtype, tf_dtype)

    # 3. VALIDATE
    # Load both TFLite models and read their input and output types.
    initial_input_dtype, initial_output_dtype = _get_interface_dtypes(
        initial_file)
    final_input_dtype, final_output_dtype = _get_interface_dtypes(final_file)

    # Validate the model interfaces
    self.assertEqual(initial_input_dtype, np.float32)
    self.assertEqual(initial_output_dtype, np.float32)
    self.assertEqual(final_input_dtype, np_dtype)
    self.assertEqual(final_output_dtype, np_dtype)

  def testInt8Interface(self):
    """Requests an int8 input and output interface."""
    self._assert_interface_modified(tf.int8, np.int8)

  def testInt16Interface(self):
    """Requests an int16 interface on an int16-activation/int8-weight model."""
    self._assert_interface_modified(
        tf.int16,
        np.int16,
        supported_ops=tf.lite.OpsSet
        .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8)

  def testUInt8Interface(self):
    """Requests a uint8 input and output interface."""
    self._assert_interface_modified(tf.uint8, np.uint8)


if __name__ == '__main__':
  test.main()