# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for python.compiler.mlir."""

from tensorflow.python.compiler.mlir import mlir
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.pywrap_mlir import import_graphdef


def _import_to_tf_dialect(graph_def, input_names, input_data_shapes):
  """Imports a two-float-input GraphDef through `tf-standard-pipeline`.

  Every call requests two `DT_FLOAT` inputs and the single output `Add`, so
  tests only need to spell out the arguments they actually vary.

  Args:
    graph_def: The GraphDef to import.
    input_names: Names of the graph's input nodes.
    input_data_shapes: One comma-separated shape string per input; an empty
      string requests a scalar input.

  Returns:
    The textual MLIR module produced by `import_graphdef`.
  """
  return import_graphdef(
      graph_def,
      "tf-standard-pipeline",
      False,
      input_names=input_names,
      input_data_types=["DT_FLOAT", "DT_FLOAT"],
      input_data_shapes=input_data_shapes,
      output_names=["Add"])


class MLIRGraphDefImportTest(test.TestCase):
  """Tests for importing GraphDefs into MLIR."""

  def testImport(self):
    """Tests the basic flow of `tf.mlir.experimental.convert_graph_def`."""
    mlir_module = mlir.convert_graph_def('')
    # An empty graph should contain at least an empty main function.
    self.assertIn('func @main', mlir_module)

  def testInvalidPbtxt(self):
    """Tests that text which is not a parseable GraphDef proto is rejected."""
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                'Could not parse input proto'):
      mlir.convert_graph_def('some invalid proto')

  def testGraphDefToTf(self):
    """Tests `import_graphdef` with the `tf-standard-pipeline` pass pipeline.

    Imports the GraphDef of a simple two-input add function all the way to the
    TF dialect, and checks that explicitly supplied input names, shapes and
    output names are reflected in the resulting module, or rejected when they
    are inconsistent.
    """
    tensor_shape = (10, 10)

    @def_function.function(
        input_signature=(
            tensor_spec.TensorSpec(shape=tensor_shape, dtype=dtypes.float32),
            tensor_spec.TensorSpec(shape=tensor_shape, dtype=dtypes.float32),
        ))
    def add_func(lhs, rhs):
      return math_ops.add(lhs, rhs)

    tf_graph_def = add_func.get_concrete_function().graph.as_graph_def()

    # The main function signature must use the requested input shapes, and
    # the requested input/output node names must be recorded in the module.
    mlir_tf = _import_to_tf_dialect(
        tf_graph_def,
        input_names=["lhs", "rhs"],
        input_data_shapes=["10,10", "10,10"])
    self.assertRegex(
        mlir_tf,
        r"func @main\(%arg0: tensor<10x10xf32>, %arg1: tensor<10x10xf32>")
    self.assertRegex(mlir_tf, r'inputs = "lhs,rhs"')
    self.assertRegex(mlir_tf, r'outputs = "Add"')

    # An empty shape string imports the corresponding input as a scalar.
    mlir_tf = _import_to_tf_dialect(
        tf_graph_def,
        input_names=["lhs", "rhs"],
        input_data_shapes=["", ""])
    self.assertRegex(mlir_tf,
                     r"func @main\(%arg0: tensor<f32>, %arg1: tensor<f32>")

    # Fewer input names than input data types must be rejected.
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "Length of input node array and data type doesn't match"):
      _import_to_tf_dialect(
          tf_graph_def,
          input_names=["lhs"],
          input_data_shapes=["10,10", "10,10"])

    # Input shapes that are incompatible for the add op must be rejected.
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "Dimensions must be equal"):
      _import_to_tf_dialect(
          tf_graph_def,
          input_names=["lhs", "rhs"],
          input_data_shapes=["10,11", "10,10"])


class MLIRConcreteFunctionImportTest(test.TestCase):
  """Tests for importing concrete functions into MLIR."""

  @test_util.run_v2_only
  def testImport(self):
    """Tests `convert_function` on a simple function, with debug info."""

    @def_function.function
    def sqr(i):
      return i * i

    concrete_function = sqr.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.float32))
    mlir_module = mlir.convert_function(concrete_function, show_debug_info=True)
    self.assertRegex(mlir_module, r'func @.*sqr.*\(')
    # With show_debug_info=True the module is expected to carry source
    # locations pointing back into this test file.
    self.assertRegex(mlir_module, r'callsite\(".*mlir_test\.py":')

  @test_util.run_v2_only
  def testImportWithCall(self):
    """Tests that a called function is imported as a private function."""

    @def_function.function
    def callee(i):
      return i

    @def_function.function
    def caller(i):
      return callee(i)

    concrete_function = caller.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.float32))
    mlir_module = mlir.convert_function(concrete_function)
    self.assertRegex(mlir_module, r'func @.*caller.*\(')
    self.assertRegex(mlir_module, r'func private @.*callee.*\(')

  @test_util.run_v2_only
  def testImportWithControlRet(self):
    """Tests that a side-effecting op is kept alive as a control return."""

    @def_function.function
    def logging():
      logging_ops.print_v2('some message')

    concrete_function = logging.get_concrete_function()
    # An empty pass pipeline is used; the assertions below expect the result
    # to still contain tf_executor ops.
    mlir_module = mlir.convert_function(concrete_function, pass_pipeline='')
    self.assertRegex(mlir_module, r'tf\.PrintV2')
    # The function returns nothing, so the print is only reachable through a
    # control-typed fetch.
    self.assertRegex(mlir_module,
                     r'tf_executor\.fetch.*: !tf_executor\.control')


if __name__ == '__main__':
  test.main()