# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Tensorflow -> jitrt compilation of broadcasting binary ops."""

import numpy as np

from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
from tensorflow.python.platform import test

jitrt = tf_jitrt.TfJitRtExecutor()

# Specialization modes that the parameterized tests below iterate over.
specializations = [
    tf_jitrt.Specialization.ENABLED,
    tf_jitrt.Specialization.DISABLED,
    tf_jitrt.Specialization.ALWAYS,
]


class TfBinaryBcastTest(test.TestCase):
  """Checks jitrt-compiled TF binary ops against numpy broadcasting."""

  def test_bcast_2d_1d(self):
    """Broadcasts 1-D operands along the rows of a dynamic 2-D tensor."""
    mlir_function = """
      func.func @test(%arg0: tensor<?x4xf32>,
                      %arg1: tensor<4xf32>,
                      %arg2: tensor<4xf32>) -> tensor<?x4xf32> {
        %0 = "tf.Log1p"(%arg0)
             : (tensor<?x4xf32>) -> tensor<?x4xf32>
        %1 = "tf.Sub"(%0, %arg1)
             : (tensor<?x4xf32>, tensor<4xf32>) -> tensor<?x4xf32>
        %2 = "tf.Mul"(%1, %arg2)
             : (tensor<?x4xf32>, tensor<4xf32>) -> tensor<?x4xf32>
        %3 = "tf.Atan2"(%2, %arg2)
             : (tensor<?x4xf32>, tensor<4xf32>) -> tensor<?x4xf32>
        func.return %3 : tensor<?x4xf32>
      }"""

    n = np.random.randint(1, 10)

    arg0 = np.random.uniform(0, 10.0, size=(n, 4)).astype(np.float32)
    arg1 = np.random.uniform(0, 10.0, size=(4,)).astype(np.float32)
    arg2 = np.random.uniform(0, 10.0, size=(4,)).astype(np.float32)

    for specialize in specializations:
      for vectorize in [True, False]:
        compiled = jitrt.compile(mlir_function, 'test', specialize, vectorize)

        [res] = jitrt.execute(compiled, [arg0, arg1, arg2])
        ref = np.arctan2((np.log1p(arg0) - arg1) * arg2, arg2)
        np.testing.assert_allclose(res, ref, atol=1e-04)

  def test_bcast_2d_2d(self):
    """Multiplies every combination of degenerate / full 2-D shapes."""
    mlir_function = """
      func.func @test(%arg0: tensor<?x?xf32>,
                      %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
        %0 = "tf.Mul"(%arg0, %arg1)
             : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
        func.return %0 : tensor<?x?xf32>
      }"""

    m = np.random.randint(1, 10)
    n = np.random.randint(1, 10)

    # Shapes (1, 1), (1, n), (m, 1) and (m, n): all pairs are broadcastable.
    lhs0 = np.random.uniform(0, 10.0, size=(1, 1)).astype(np.float32)
    lhs1 = np.random.uniform(0, 10.0, size=(1, n)).astype(np.float32)
    lhs2 = np.random.uniform(0, 10.0, size=(m, 1)).astype(np.float32)
    lhs3 = np.random.uniform(0, 10.0, size=(m, n)).astype(np.float32)

    rhs0 = np.random.uniform(0, 10.0, size=(1, 1)).astype(np.float32)
    rhs1 = np.random.uniform(0, 10.0, size=(1, n)).astype(np.float32)
    rhs2 = np.random.uniform(0, 10.0, size=(m, 1)).astype(np.float32)
    rhs3 = np.random.uniform(0, 10.0, size=(m, n)).astype(np.float32)

    for specialize in specializations:
      compiled = jitrt.compile(mlir_function, 'test', specialize)

      for lhs in [lhs0, lhs1, lhs2, lhs3]:
        for rhs in [rhs0, rhs1, rhs2, rhs3]:
          [res] = jitrt.execute(compiled, [lhs, rhs])
          np.testing.assert_allclose(res, lhs * rhs, atol=1e-07)

  def test_bcast_2d_1d_0d(self):
    """Chains adds that broadcast a scalar into 1-D and 1-D into 2-D."""
    mlir_function = """
      func.func @compute(%arg0: tensor<?x4xf32>,
                         %arg1: tensor<4xf32>,
                         %arg2: tensor<f32>) -> tensor<?x4xf32> {
        %0 = "tf.AddV2"(%arg1, %arg2)
             : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
        %1 = "tf.AddV2"(%arg0, %0)
             : (tensor<?x4xf32>, tensor<4xf32>) -> tensor<?x4xf32>
        %2 = "tf.AddV2"(%1, %0)
             : (tensor<?x4xf32>, tensor<4xf32>) -> tensor<?x4xf32>
        func.return %2 : tensor<?x4xf32>
      }"""

    for specialize in specializations:
      compiled = jitrt.compile(mlir_function, 'compute', specialize)

      arg0 = np.random.uniform(0, 10.0, size=(1, 4)).astype(np.float32)
      arg1 = np.random.uniform(0, 10.0, size=(4,)).astype(np.float32)
      arg2 = np.random.uniform(0, 10.0, size=()).astype(np.float32)

      [res] = jitrt.execute(compiled, [arg0, arg1, arg2])

      # Reference implementation with numpy, mirroring the op order above so
      # that an exact (atol=0.0) comparison is meaningful.
      t_0 = np.add(arg1, arg2)
      t_1 = np.add(arg0, t_0)
      t_2 = np.add(t_1, t_0)

      np.testing.assert_allclose(res, t_2, atol=0.0)

  def test_bcast_3d_3d(self):
    """Adds two same-shaped 3-D tensors with dynamic leading dimensions."""
    mlir_function = """
      func.func @test(%arg0: tensor<?x?x12xf32>,
                      %arg1: tensor<?x?x12xf32>) -> tensor<?x?x12xf32> {
        %0 = "tf.AddV2"(%arg0, %arg1)
             : (tensor<?x?x12xf32>, tensor<?x?x12xf32>) -> tensor<?x?x12xf32>
        func.return %0 : tensor<?x?x12xf32>
      }"""

    d0 = np.random.randint(1, 10)
    d1 = np.random.randint(1, 10)

    arg0 = np.random.uniform(0, 10.0, size=(d0, d1, 12)).astype(np.float32)
    arg1 = np.random.uniform(0, 10.0, size=(d0, d1, 12)).astype(np.float32)

    for specialize in specializations:
      for vectorize in [True, False]:
        compiled = jitrt.compile(mlir_function, 'test', specialize, vectorize)

        [res] = jitrt.execute(compiled, [arg0, arg1])
        np.testing.assert_allclose(res, arg0 + arg1, atol=0.0)

  def test_bcast_unranked_0d(self):
    """Adds a scalar to an unranked operand carrying a rank constraint."""
    mlir_function = """
      func.func @compute(%arg0: tensor<*xf32> {rt.constraint = "rank"},
                         %arg1: tensor<f32>) -> tensor<*xf32> {
        %0 = "tf.AddV2"(%arg0, %arg1)
             : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
        func.return %0 : tensor<*xf32>
      }"""

    compiled = jitrt.compile(mlir_function, 'compute')

    arg0 = np.random.uniform(0, 10.0, size=(4, 4)).astype(np.float32)
    arg1 = np.random.uniform(0, 10.0, size=()).astype(np.float32)

    [res] = jitrt.execute(compiled, [arg0, arg1])

    np.testing.assert_allclose(res, np.add(arg0, arg1), atol=0.0)

  def test_bcast_unranked_unranked(self):
    """Broadcasts (1, 4) + (4, 1) when both operands are unranked."""
    mlir_function = """
      func.func @compute(%arg0: tensor<*xf32> {rt.constraint = "rank"},
                         %arg1: tensor<*xf32> {rt.constraint = "rank"})
          -> tensor<*xf32> {
        %0 = "tf.AddV2"(%arg0, %arg1)
             : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
        func.return %0 : tensor<*xf32>
      }"""

    compiled = jitrt.compile(mlir_function, 'compute')

    arg0 = np.random.uniform(0, 10.0, size=(1, 4)).astype(np.float32)
    arg1 = np.random.uniform(0, 10.0, size=(4, 1)).astype(np.float32)

    [res] = jitrt.execute(compiled, [arg0, arg1])

    np.testing.assert_allclose(res, np.add(arg0, arg1), atol=0.0)

  def test_bcast_1d_1d_error(self):
    """Checks that non-broadcastable shapes are reported at run time."""
    mlir_function = """
      func.func @compute(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>)
          -> tensor<?xf32> {
        %0 = "tf.AddV2"(%arg0, %arg1)
             : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
        func.return %0 : tensor<?xf32>
      }"""

    # Sizes 2 and 3 are incompatible: neither is 1 and they differ.
    arg0 = np.random.uniform(0, 10.0, size=(2,)).astype(np.float32)
    arg1 = np.random.uniform(0, 10.0, size=(3,)).astype(np.float32)

    for specialize in specializations:
      compiled = jitrt.compile(mlir_function, 'compute', specialize)

      with self.assertRaisesRegex(Exception, 'required broadcastable shapes'):
        jitrt.execute(compiled, [arg0, arg1])

  def test_bcast_value_rank0(self):
    """Checks that 0-ranked operands are correctly value-specialized."""
    mlir_function = """
      func.func @compute(%arg0: tensor<*xi32>,
                         %arg1: tensor<i32> {rt.constraint = "value"})
          -> tensor<*xi32> {
        %0 = "tf.AddV2"(%arg0, %arg1)
             : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
        func.return %0 : tensor<*xi32>
      }"""
    compiled = jitrt.compile(mlir_function, 'compute')

    # Run the same compiled module with two different value-specialized
    # arguments and check that each result matches its own argument.
    tensor = np.random.uniform(0, 10.0, size=(3,)).astype(np.int32)
    rhs0 = np.random.uniform(0, 10.0, size=()).astype(np.int32)
    # Two independent draws from [0, 10) truncated to int32 would collide
    # roughly 1 time in 10, silently testing a single specialization. Derive
    # rhs1 from rhs0 with a non-zero offset (mod 10) so the values always
    # differ while staying in [0, 10); keep it a 0-d int32 array like rhs0.
    offset = np.random.randint(1, 10)
    rhs1 = np.array((int(rhs0) + offset) % 10, dtype=np.int32)

    [res0] = jitrt.execute(compiled, [tensor, rhs0])
    [res1] = jitrt.execute(compiled, [tensor, rhs1])
    np.testing.assert_allclose(res0, np.add(tensor, rhs0), atol=0.0)
    np.testing.assert_allclose(res1, np.add(tensor, rhs1), atol=0.0)

  def test_bcast_value_die_if_unsinkable(self):
    """Checks that value-specializing an f32 operand fails to compile."""
    mlir_function = """
      func.func @compute(%arg0: tensor<*xf32>,
                         %arg1: tensor<f32> {rt.constraint = "value"})
          -> tensor<*xf32> {
        %0 = "tf.AddV2"(%arg0, %arg1)
             : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
        func.return %0 : tensor<*xf32>
      }"""

    with self.assertRaisesRegex(Exception,
                                'cannot sink operand type: tensor<f32>'):
      jitrt.compile(mlir_function, 'compute')


if __name__ == '__main__':
  test.main()