• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for python.compiler.mlir."""
16
17from tensorflow.python.compiler.mlir import mlir
18from tensorflow.python.eager import def_function
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import errors
21from tensorflow.python.framework import tensor_spec
22from tensorflow.python.framework import test_util
23from tensorflow.python.ops import logging_ops
24from tensorflow.python.ops import math_ops
25from tensorflow.python.platform import test
26from tensorflow.python.pywrap_mlir import import_graphdef
27
28
class MLIRGraphDefImportTest(test.TestCase):
  """Tests for importing GraphDefs into MLIR."""

  def testImport(self):
    """Tests the basic flow of `tf.mlir.experimental.convert_graph_def`."""
    mlir_module = mlir.convert_graph_def('')
    # An empty graph should contain at least an empty main function.
    self.assertIn('func @main', mlir_module)

  def testInvalidPbtxt(self):
    """Tests that a malformed text proto is rejected with a parse error."""
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                'Could not parse input proto'):
      mlir.convert_graph_def('some invalid proto')

  def testGraphDefToTf(self):
    """Tests importing a GraphDef through `tf-standard-pipeline`.

    The graph is converted all the way to the TF dialect. The resulting
    `main` function must reflect the requested input/output names, data
    types and shapes. Inconsistent specs must be rejected.
    """
    tensor_shape = (10, 10)

    @def_function.function(
        input_signature=(
            tensor_spec.TensorSpec(shape=tensor_shape, dtype=dtypes.float32),
            tensor_spec.TensorSpec(shape=tensor_shape, dtype=dtypes.float32),
        ))
    def add_func(lhs, rhs):
      return math_ops.add(lhs, rhs)

    tf_graph_def = add_func.get_concrete_function().graph.as_graph_def()

    def import_add_graph(**overrides):
      """Imports `tf_graph_def` via the TF standard pipeline.

      Starts from a valid two-input spec matching `add_func`. Any keyword
      given in `overrides` replaces the corresponding default, so each case
      below states only what it changes.
      """
      import_args = {
          'input_names': ['lhs', 'rhs'],
          'input_data_types': ['DT_FLOAT', 'DT_FLOAT'],
          'input_data_shapes': ['10,10', '10,10'],
          'output_names': ['Add'],
      }
      import_args.update(overrides)
      return import_graphdef(tf_graph_def, 'tf-standard-pipeline', False,
                             **import_args)

    # The MLIR function signature must carry the requested inputs/outputs.
    mlir_tf = import_add_graph()
    self.assertRegex(
        mlir_tf,
        r'func @main\(%arg0: tensor<10x10xf32>, %arg1: tensor<10x10xf32>')
    self.assertRegex(mlir_tf, r'inputs = "lhs,rhs"')
    self.assertRegex(mlir_tf, r'outputs = "Add"')

    # An empty input shape string yields scalar inputs.
    mlir_tf = import_add_graph(input_data_shapes=['', ''])
    self.assertRegex(mlir_tf,
                     r'func @main\(%arg0: tensor<f32>, %arg1: tensor<f32>')

    # The number of input names must match the number of input data types.
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "Length of input node array and data type doesn't match"):
      import_add_graph(input_names=['lhs'])

    # Input shapes that make the `Add` operands incompatible must be rejected.
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                'Dimensions must be equal'):
      import_add_graph(input_data_shapes=['10,11', '10,10'])
109          False,
110          input_names=["lhs", "rhs"],
111          input_data_types=["DT_FLOAT", "DT_FLOAT"],
112          input_data_shapes=["10,11", "10,10"],
113          output_names=["Add"])
114
115
116class MLIRConcreteFunctionImportTest(test.TestCase):
117
118  @test_util.run_v2_only
119  def testImport(self):
120
121    @def_function.function
122    def sqr(i):
123      return i * i
124
125    concrete_function = sqr.get_concrete_function(
126        tensor_spec.TensorSpec(None, dtypes.float32))
127    mlir_module = mlir.convert_function(concrete_function, show_debug_info=True)
128    self.assertRegex(mlir_module, r'func @.*sqr.*\(')
129    self.assertRegex(mlir_module, r'callsite\(".*mlir_test.py":')
130
131  @test_util.run_v2_only
132  def testImportWithCall(self):
133
134    @def_function.function
135    def callee(i):
136      return i
137
138    @def_function.function
139    def caller(i):
140      return callee(i)
141
142    concrete_function = caller.get_concrete_function(
143        tensor_spec.TensorSpec(None, dtypes.float32))
144    mlir_module = mlir.convert_function(concrete_function)
145    self.assertRegex(mlir_module, r'func @.*caller.*\(')
146    self.assertRegex(mlir_module, r'func private @.*callee.*\(')
147
148  @test_util.run_v2_only
149  def testImportWithControlRet(self):
150
151    @def_function.function
152    def logging():
153      logging_ops.print_v2('some message')
154
155    concrete_function = logging.get_concrete_function()
156    mlir_module = mlir.convert_function(concrete_function, pass_pipeline='')
157    self.assertRegex(mlir_module, r'tf\.PrintV2')
158    self.assertRegex(mlir_module, r'tf_executor.fetch.*: !tf_executor.control')
159
160
161if __name__ == '__main__':
162  test.main()
163