• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for modify_model_interface_lib.py."""
16
17import os
18import numpy as np
19import tensorflow as tf
20
21from tensorflow.lite.tools.optimize.python import modify_model_interface_lib
22from tensorflow.python.framework import test_util
23from tensorflow.python.platform import test
24
25
def build_tflite_model_with_full_integer_quantization(
    supported_ops=tf.lite.OpsSet.TFLITE_BUILTINS_INT8):
  """Builds a tiny Keras model and converts it to a quantized TFLite model.

  The Keras model takes a float32 input of size 3 and stacks a 5-unit ReLU
  dense layer and a 2-unit softmax dense layer. Conversion uses the default
  optimizations together with a ten-sample representative dataset.

  Args:
    supported_ops: The single `tf.lite.OpsSet` entry placed in the converter's
      `target_spec.supported_ops` list.

  Returns:
    The converted model, exactly as returned by `TFLiteConverter.convert()`.
  """
  num_features = 3

  def calibration_samples():
    # Sample k is a float32 vector of length `num_features` filled with k.
    for value in range(10):
      yield [np.full(num_features, value, dtype=np.float32)]

  # Define the TF model.
  layer_stack = [
      tf.keras.layers.InputLayer(
          input_shape=(num_features,), dtype=tf.float32),
      tf.keras.layers.Dense(units=5, activation=tf.nn.relu),
      tf.keras.layers.Dense(units=2, activation=tf.nn.softmax),
  ]
  keras_model = tf.keras.Sequential(layer_stack)

  # Convert the TF model to a quantized TFLite model.
  tflite_converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
  tflite_converter.optimizations = [tf.lite.Optimize.DEFAULT]
  tflite_converter.representative_dataset = calibration_samples
  tflite_converter.target_spec.supported_ops = [supported_ops]
  return tflite_converter.convert()
49
50
class ModifyModelInterfaceTest(test_util.TensorFlowTestCase):
  """Checks that modify_model_interface changes a model's input/output dtypes.

  Every test quantizes a small model, runs `modify_model_interface` on it with
  one target dtype, and compares the interface dtypes reported by the TFLite
  interpreter before and after the call.
  """

  def _check_modified_interface(self, initial_model, tf_dtype, np_dtype):
    """Runs modify_model_interface on a model and validates the result.

    Args:
      initial_model: Model contents written (in binary mode) to the initial
        model file.
      tf_dtype: TensorFlow dtype passed as both type arguments of
        `modify_model_interface`.
      np_dtype: NumPy dtype that the first input and the first output of the
        final model are expected to report.
    """
    # 1. SETUP
    # Define the temporary directory and files
    temp_dir = self.get_temp_dir()
    initial_file = os.path.join(temp_dir, 'initial_model.tflite')
    final_file = os.path.join(temp_dir, 'final_model.tflite')
    with open(initial_file, 'wb') as model_file:
      model_file.write(initial_model)

    # 2. INVOKE
    # Invoke the modify_model_interface function
    modify_model_interface_lib.modify_model_interface(initial_file, final_file,
                                                      tf_dtype, tf_dtype)

    # 3. VALIDATE
    # Load TFLite model and allocate tensors.
    initial_interpreter = tf.lite.Interpreter(model_path=initial_file)
    initial_interpreter.allocate_tensors()
    final_interpreter = tf.lite.Interpreter(model_path=final_file)
    final_interpreter.allocate_tensors()

    # Get input and output types.
    initial_input_dtype = initial_interpreter.get_input_details()[0]['dtype']
    initial_output_dtype = initial_interpreter.get_output_details()[0]['dtype']
    final_input_dtype = final_interpreter.get_input_details()[0]['dtype']
    final_output_dtype = final_interpreter.get_output_details()[0]['dtype']

    # Validate the model interfaces: the initial model must still expose a
    # float32 interface, and the final model the requested dtype.
    self.assertEqual(initial_input_dtype, np.float32)
    self.assertEqual(initial_output_dtype, np.float32)
    self.assertEqual(final_input_dtype, np_dtype)
    self.assertEqual(final_output_dtype, np_dtype)

  def testInt8Interface(self):
    """Model built with the default ops set gets an int8 interface."""
    initial_model = build_tflite_model_with_full_integer_quantization()
    self._check_modified_interface(initial_model, tf.int8, np.int8)

  def testInt16Interface(self):
    """Model built with int16 activations gets an int16 interface."""
    initial_model = build_tflite_model_with_full_integer_quantization(
        supported_ops=tf.lite.OpsSet
        .EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8)
    self._check_modified_interface(initial_model, tf.int16, np.int16)

  def testUInt8Interface(self):
    """Model built with the default ops set gets a uint8 interface."""
    initial_model = build_tflite_model_with_full_integer_quantization()
    self._check_modified_interface(initial_model, tf.uint8, np.uint8)
159
160
# Run the tests through the TensorFlow test runner when executed as a script.
if __name__ == '__main__':
  test.main()
163