• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for analyzer package."""
16
17import io
18import sys
19
20import tensorflow as tf
21
22from tensorflow.lite.python import analyzer
23from tensorflow.python.framework import test_util
24from tensorflow.python.platform import resource_loader
25from tensorflow.python.platform import test
26
27
class AnalyzerTest(test_util.TensorFlowTestCase):
  """Checks the reports printed by `analyzer.ModelAnalyzer.analyze`.

  Each test runs the analyzer on a TFLite model (loaded from a testdata file
  or converted in-process) and asserts on substrings of the printed text or
  MLIR dump.
  """

  def _analyze_and_capture(self, **kwargs):
    """Runs `analyzer.ModelAnalyzer.analyze` and returns what it printed.

    `sys.stdout` is replaced by an in-memory buffer for the duration of the
    call so the report can be inspected with string assertions.

    Args:
      **kwargs: Forwarded unchanged to `analyzer.ModelAnalyzer.analyze`
        (this file uses `model_path`, `model_content`,
        `experimental_use_mlir` and `gpu_compatibility`).

    Returns:
      Everything written to `sys.stdout` during the call, as a `str`.
    """
    mock_stdout = io.StringIO()
    with test.mock.patch.object(sys, 'stdout', mock_stdout):
      analyzer.ModelAnalyzer.analyze(**kwargs)
    return mock_stdout.getvalue()

  def _convert_add_cos_model(self):
    """Converts `f(x) = x + cos(x)` on a 1-D float32 input to TFLite.

    Returns:
      The result of `TFLiteConverter.convert()`, suitable for passing to the
      analyzer as `model_content`.
    """

    @tf.function(
        input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def func(x):
      return x + tf.cos(x)

    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [func.get_concrete_function()], func)
    return converter.convert()

  def testTxt(self):
    model_path = resource_loader.get_path_to_datafile('../testdata/add.bin')
    txt = self._analyze_and_capture(model_path=model_path)
    self.assertIn('Subgraph#0(T#1) -> [T#2]', txt)
    self.assertIn('Op#0 ADD(T#1, T#1) -> [T#0]', txt)
    self.assertIn('Op#1 ADD(T#0, T#1) -> [T#2]', txt)
    # The GPU report must only appear when gpu_compatibility is requested.
    self.assertNotIn('Your model looks compatibile with GPU delegate', txt)

  def testMlir(self):
    model_path = resource_loader.get_path_to_datafile('../testdata/add.bin')
    mlir = self._analyze_and_capture(
        model_path=model_path, experimental_use_mlir=True)
    self.assertIn(
        'func @main(%arg0: tensor<1x8x8x3xf32>) -> '
        'tensor<1x8x8x3xf32> attributes '
        '{tf.entry_function = {inputs = "input", outputs = "output"}}', mlir)
    self.assertIn(
        '%0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE"} : '
        'tensor<1x8x8x3xf32>', mlir)
    self.assertIn(
        '%1 = tfl.add %0, %arg0 {fused_activation_function = "NONE"} : '
        'tensor<1x8x8x3xf32>', mlir)
    self.assertIn('return %1 : tensor<1x8x8x3xf32>', mlir)

  def testMlirHugeConst(self):
    model_path = resource_loader.get_path_to_datafile(
        '../testdata/conv_huge_im2col.bin')
    mlir = self._analyze_and_capture(
        model_path=model_path, experimental_use_mlir=True)
    # The constant's payload is expected as an opaque placeholder value
    # ("0xDEADBEEF") rather than its actual contents.
    self.assertIn(
        '%1 = "tfl.pseudo_const"() {value = opaque<"_", "0xDEADBEEF"> : '
        'tensor<3x3x3x8xf32>} : () -> tensor<3x3x3x8xf32>', mlir)

  def testTxtWithFlatBufferModel(self):
    fb_model = self._convert_add_cos_model()
    txt = self._analyze_and_capture(model_content=fb_model)
    self.assertIn('Subgraph#0 main(T#0) -> [T#2]', txt)
    self.assertIn('Op#0 COS(T#0) -> [T#1]', txt)
    self.assertIn('Op#1 ADD(T#0, T#1) -> [T#2]', txt)

  def testMlirWithFlatBufferModel(self):
    fb_model = self._convert_add_cos_model()
    mlir = self._analyze_and_capture(
        model_content=fb_model, experimental_use_mlir=True)
    self.assertIn('func @main(%arg0: tensor<?xf32>) -> tensor<?xf32>', mlir)
    self.assertIn('%0 = "tfl.cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>',
                  mlir)
    self.assertIn(
        '%1 = tfl.add %arg0, %0 {fused_activation_function = "NONE"} : '
        'tensor<?xf32>', mlir)
    # Match the complete return type, as testMlir does (previously the
    # closing '>' was missing, weakening the check).
    self.assertIn('return %1 : tensor<?xf32>', mlir)

  def testTxtGpuCompatiblity(self):
    model_path = resource_loader.get_path_to_datafile(
        '../testdata/multi_add_flex.bin')
    txt = self._analyze_and_capture(
        model_path=model_path, gpu_compatibility=True)
    self.assertIn(
        'GPU COMPATIBILITY WARNING: Not supported custom op FlexAddV2', txt)
    self.assertIn(
        'GPU COMPATIBILITY WARNING: Subgraph#0 has GPU delegate compatibility '
        'issues at nodes 0, 1, 2', txt)

  def testTxtGpuCompatiblityPass(self):
    fb_model = self._convert_add_cos_model()
    txt = self._analyze_and_capture(
        model_content=fb_model, gpu_compatibility=True)
    # The expected text (including its spelling) is what the analyzer prints.
    self.assertIn(
        'Your model looks compatibile with GPU delegate with TFLite runtime',
        txt)
146
if __name__ == '__main__':
  # Script entry point: hand control to the TensorFlow platform test runner
  # (`tensorflow.python.platform.test`) for the test cases in this module.
  test.main()
149