# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Tensorflow -> jitrt compilation of tf.Transpose."""

import numpy as np

from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
from tensorflow.python.platform import test

specializations = [
    tf_jitrt.Specialization.ENABLED,
    tf_jitrt.Specialization.DISABLED,
    tf_jitrt.Specialization.ALWAYS,
]

jitrt = tf_jitrt.TfJitRtExecutor()

# MLIR for a single rank-3 tf.Transpose with a constant i64 permutation.
# `{perm}` is filled in with a Python list (e.g. "[0, 2, 1]"); every literal
# MLIR brace is doubled so that str.format leaves it untouched.
_TRANSPOSE_3D_MLIR_TEMPLATE = """
  func.func @test(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {{
    %0 = "tf.Const"() {{ value = dense<{perm}> : tensor<3xi64> }}
         : () -> tensor<3xi64>
    %1 = "tf.Transpose"(%arg0, %0)
         : (tensor<?x?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
    func.return %1 : tensor<?x?x?xf32>
  }}"""


class TfTransposeTest(test.TestCase):

  def _compile(self, mlir_function, specialize):
    """Compiles `@test` in `mlir_function` with vectorized transpose codegen.

    Args:
      mlir_function: MLIR module text defining a function named `test`.
      specialize: one of the `tf_jitrt.Specialization` modes.

    Returns:
      The handle returned by `jitrt.compile`, to be passed to `jitrt.execute`.
    """
    return jitrt.compile(
        mlir_function,
        'test',
        specialize,
        vectorize=True,
        codegen_transpose=True)

  def _check_transpose_3d(self, perm):
    """Checks a constant-permutation rank-3 transpose against np.transpose.

    The input is a 32x32x32 tensor filled with 0..32**3-1, so every element
    is distinct and exactly representable in float32. Any misplaced element
    is therefore caught by an exact comparison.

    Args:
      perm: a length-3 permutation of (0, 1, 2).
    """
    mlir_function = _TRANSPOSE_3D_MLIR_TEMPLATE.format(perm=list(perm))
    dim_size = 32
    for specialize in specializations:
      # subTest reports which specialization mode failed.
      with self.subTest(specialize=specialize):
        compiled = self._compile(mlir_function, specialize)
        arg0 = np.arange(0, dim_size * dim_size * dim_size, 1,
                         np.float32).reshape((dim_size, dim_size, dim_size))
        [res] = jitrt.execute(compiled, [arg0])
        np.testing.assert_array_equal(res, np.transpose(arg0, perm))

  def test_transpose_2d(self):
    """2-D transpose of a randomly sized dynamic-shape input."""
    mlir_function = """
      func.func @test(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
        %0 = "tf.Const"() { value = dense<[1, 0]> : tensor<2xi32> }
             : () -> tensor<2xi32>
        %1 = "tf.Transpose"(%arg0, %0)
             : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
        func.return %1 : tensor<?x?xf32>
      }"""

    for specialize in specializations:
      with self.subTest(specialize=specialize):
        compiled = self._compile(mlir_function, specialize)

        d0 = np.random.randint(1, 10)
        d1 = np.random.randint(1, 10)

        arg0 = np.random.uniform(0, 10.0, size=(d0, d1)).astype(np.float32)

        [res] = jitrt.execute(compiled, [arg0])
        np.testing.assert_allclose(res, np.transpose(arg0), atol=0.0)

  def test_transpose_3d_0_2_1(self):
    self._check_transpose_3d((0, 2, 1))

  def test_transpose_3d_2_0_1(self):
    self._check_transpose_3d((2, 0, 1))

  def test_transpose_3d_2_1_0(self):
    self._check_transpose_3d((2, 1, 0))

  def test_transpose_3d_1_2_0(self):
    self._check_transpose_3d((1, 2, 0))

  def test_transpose_3d_1_0_2(self):
    self._check_transpose_3d((1, 0, 2))

  def test_double_transpose_3d(self):
    """Two chained constant-permutation transposes on a random 3-D input."""
    mlir_function = """
      func.func @test(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
        %0 = "tf.Const"() { value = dense<[0, 2, 1]> : tensor<3xi32> }
             : () -> tensor<3xi32>
        %1 = "tf.Const"() { value = dense<[2, 1, 0]> : tensor<3xi32> }
             : () -> tensor<3xi32>
        %2 = "tf.Transpose"(%arg0, %0)
             : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
        %3 = "tf.Transpose"(%2, %1)
             : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
        func.return %3 : tensor<?x?x?xf32>
      }"""

    for specialize in specializations:
      with self.subTest(specialize=specialize):
        compiled = self._compile(mlir_function, specialize)

        d0 = np.random.randint(1, 10)
        d1 = np.random.randint(1, 10)
        d2 = np.random.randint(1, 10)

        arg0 = np.random.uniform(
            0, 10.0, size=(d0, d1, d2)).astype(np.float32)

        [res] = jitrt.execute(compiled, [arg0])
        ref = np.transpose(np.transpose(arg0, (0, 2, 1)), (2, 1, 0))
        np.testing.assert_allclose(res, ref, atol=0.0)

  # Without value specialization, the below tf.Transpose won't compile because
  # the permutation vector must be statically shaped.
  def test_transpose_value_specialization_i32(self):
    """An i32 permutation operand constrained to value specialization."""
    mlir_function = """
      func.func @compute(%arg0: tensor<*xf32>,
                         %arg1: tensor<?xi32> {rt.constraint = "value"})
          -> tensor<*xf32> {
        %0 = "tf.Transpose"(%arg0, %arg1)
             : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
        func.return %0 : tensor<*xf32>
      }"""
    compiled = jitrt.compile(mlir_function, 'compute')
    tensor = np.random.uniform(0, 10.0, size=(3, 3)).astype(np.float32)
    perm0 = np.array([1, 0]).astype(np.int32)
    perm1 = np.array([0, 1]).astype(np.int32)

    # Test that the same compiled module with two different value-specialized
    # arguments is handled correctly, i.e. it is specialized twice.
    [res0] = jitrt.execute(compiled, [tensor, perm0])
    [res1] = jitrt.execute(compiled, [tensor, perm1])
    np.testing.assert_allclose(res0, np.transpose(tensor, perm0), atol=0.0)
    np.testing.assert_allclose(res1, np.transpose(tensor, perm1), atol=0.0)

  # Test value specialization of two i64 operands.
  def test_transpose_value_specialization_i64(self):
    """Two value-specialized i64 permutation operands in one function."""
    mlir_function = """
      func.func @compute(%arg0: tensor<*xf32>,
                         %arg1: tensor<?xi64> {rt.constraint = "value"},
                         %arg2: tensor<?xi64> {rt.constraint = "value"})
          -> tensor<*xf32> {
        %0 = "tf.Transpose"(%arg0, %arg1)
             : (tensor<*xf32>, tensor<?xi64>) -> tensor<*xf32>
        %1 = "tf.Transpose"(%0, %arg2)
             : (tensor<*xf32>, tensor<?xi64>) -> tensor<*xf32>
        func.return %1 : tensor<*xf32>
      }"""
    compiled = jitrt.compile(mlir_function, 'compute')
    tensor = np.random.uniform(0, 10.0, size=(3, 3)).astype(np.float32)
    perm0 = np.array([1, 0]).astype(np.int64)
    perm1 = np.array([0, 1]).astype(np.int64)

    [res] = jitrt.execute(compiled, [tensor, perm0, perm1])
    np.testing.assert_allclose(
        res, np.transpose(np.transpose(tensor, perm0), perm1), atol=0.0)

  # Test that without the value constraint the function cannot compile
  # because the permutation vector is not statically shaped.
  def test_transpose_die_without_value_specialization(self):
    """Compilation must fail when the permutation is not value-constrained."""
    mlir_function = """
      func.func @compute(%arg0: tensor<*xf32>,
                         %arg1: tensor<?xi64>) -> tensor<*xf32> {
        %0 = "tf.Transpose"(%arg0, %arg1)
             : (tensor<*xf32>, tensor<?xi64>) -> tensor<*xf32>
        func.return %0 : tensor<*xf32>
      }"""
    # The binding's concrete exception type is not visible here, so accept
    # any Exception. A successful compile is reported as a test failure.
    with self.assertRaises(Exception):
      jitrt.compile(mlir_function, 'compute')


if __name__ == '__main__':
  test.main()