# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TensorFlow reductions compiled with jitrt."""

import numpy as np

from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
from tensorflow.python.platform import test

jitrt = tf_jitrt.TfJitRtExecutor()


class TfReductionTest(test.TestCase):
  """Compiles single-op TF reduction functions with jitrt and checks results.

  Each test compiles an MLIR `func.func @test` (with `vectorize=True`), runs it
  on a random NumPy input, and compares the output with the equivalent NumPy
  reduction.
  """

  def test_1d_sum_dynamic(self):
    """tf.Sum over dim 0 of a dynamically shaped 1-D f32 tensor."""
    mlir_function = """
        func.func @test(%input: tensor<?xf32>) -> tensor<f32> {
          %dim_to_reduce = "tf.Const"() {value = dense<[0]> : tensor<1xi32>}
             : () -> tensor<1xi32>
          %0 = "tf.Sum"(%input, %dim_to_reduce) {keep_dims = false}
             : (tensor<?xf32>, tensor<1xi32>) -> tensor<f32>
          func.return %0 : tensor<f32>
        }"""

    compiled = jitrt.compile(mlir_function, 'test', vectorize=True)

    arg0 = np.random.uniform(1.0, 5.0, size=10).astype(np.float32)

    [res] = jitrt.execute(compiled, [arg0])
    np.testing.assert_allclose(res, np.sum(arg0, axis=0), atol=0.01)

  def test_1d_max_static(self):
    """tf.Max over dim 0 of a statically shaped tensor<10xf32>."""
    mlir_function = """
        func.func @test(%input: tensor<10xf32>) -> tensor<f32> {
          %dim_to_reduce = "tf.Const"() {value = dense<[0]> : tensor<1xi32>}
             : () -> tensor<1xi32>
          %0 = "tf.Max"(%input, %dim_to_reduce) {keep_dims = false}
             : (tensor<10xf32>, tensor<1xi32>) -> tensor<f32>
          func.return %0 : tensor<f32>
        }"""

    compiled = jitrt.compile(mlir_function, 'test', vectorize=True)

    # Use a non-degenerate range: uniform(1.0, 1.0) makes every element equal
    # to 1.0, which would let any kernel returning an element pass the check.
    arg0 = np.random.uniform(1.0, 5.0, size=10).astype(np.float32)

    [res] = jitrt.execute(compiled, [arg0])
    np.testing.assert_allclose(res, np.max(arg0, axis=0), atol=0.01)

  def test_1d_max_static_no_dims_to_reduce(self):
    """tf.Max with an empty reduction-dims list returns the input unchanged."""
    mlir_function = """
        func.func @test(%input: tensor<10xf32>) -> tensor<10xf32> {
          %dim_to_reduce = "tf.Const"() {value = dense<[]> : tensor<0xi32>}
             : () -> tensor<0xi32>
          %0 = "tf.Max"(%input, %dim_to_reduce) {keep_dims = false}
             : (tensor<10xf32>, tensor<0xi32>) -> tensor<10xf32>
          func.return %0 : tensor<10xf32>
        }"""

    compiled = jitrt.compile(mlir_function, 'test', vectorize=True)

    # Distinct values so an identity result is actually distinguishable from
    # e.g. a broadcast of a single element.
    arg0 = np.random.uniform(1.0, 5.0, size=10).astype(np.float32)

    [res] = jitrt.execute(compiled, [arg0])
    np.testing.assert_allclose(res, arg0, atol=0.01)

  def test_2d_row_max(self):
    """tf.Max over dim 1 of a dynamically shaped 2-D f32 tensor."""
    mlir_function = """
        func.func @test(%input: tensor<?x?xf32>) -> tensor<?xf32> {
          %dim_to_reduce = "tf.Const"() {value = dense<[1]> : tensor<1xi32>}
             : () -> tensor<1xi32>
          %0 = "tf.Max"(%input, %dim_to_reduce) {keep_dims = false}
             : (tensor<?x?xf32>, tensor<1xi32>) -> tensor<?xf32>
          func.return %0 : tensor<?xf32>
        }"""

    compiled = jitrt.compile(mlir_function, 'test', vectorize=True)

    arg0 = np.random.uniform(0.0, 10.0, size=(8, 10)).astype(np.float32)

    [res] = jitrt.execute(compiled, [arg0])
    np.testing.assert_allclose(res, np.max(arg0, axis=1), atol=0.01)

  def test_2d_row_min(self):
    """tf.Min over dim 1 of a dynamically shaped 2-D f32 tensor."""
    mlir_function = """
        func.func @test(%input: tensor<?x?xf32>) -> tensor<?xf32> {
          %dim_to_reduce = "tf.Const"() {value = dense<[1]> : tensor<1xi32>}
             : () -> tensor<1xi32>
          %0 = "tf.Min"(%input, %dim_to_reduce) {keep_dims = false}
             : (tensor<?x?xf32>, tensor<1xi32>) -> tensor<?xf32>
          func.return %0 : tensor<?xf32>
        }"""

    compiled = jitrt.compile(mlir_function, 'test', vectorize=True)

    arg0 = np.random.uniform(0.0, 10.0, size=(8, 10)).astype(np.float32)

    [res] = jitrt.execute(compiled, [arg0])
    np.testing.assert_allclose(res, np.min(arg0, axis=1), atol=0.01)

  def test_2d_row_sum(self):
    """tf.Sum over dim 1 of a dynamically shaped 2-D f32 tensor."""
    mlir_function = """
        func.func @test(%input: tensor<?x?xf32>) -> tensor<?xf32> {
          %dim_to_reduce = "tf.Const"() {value = dense<[1]> : tensor<1xi32>}
             : () -> tensor<1xi32>
          %0 = "tf.Sum"(%input, %dim_to_reduce) {keep_dims = false}
             : (tensor<?x?xf32>, tensor<1xi32>) -> tensor<?xf32>
          func.return %0 : tensor<?xf32>
        }"""

    compiled = jitrt.compile(mlir_function, 'test', vectorize=True)

    arg0 = np.random.uniform(0.0, 10.0, size=(8, 10)).astype(np.float32)

    [res] = jitrt.execute(compiled, [arg0])
    np.testing.assert_allclose(res, np.sum(arg0, axis=1), atol=0.01)

  def test_2d_row_prod(self):
    """tf.Prod over dim 1 of a dynamically shaped 2-D f32 tensor."""
    mlir_function = """
        func.func @test(%input: tensor<?x?xf32>) -> tensor<?xf32> {
          %dim_to_reduce = "tf.Const"() {value = dense<[1]> : tensor<1xi32>}
             : () -> tensor<1xi32>
          %0 = "tf.Prod"(%input, %dim_to_reduce) {keep_dims = false}
             : (tensor<?x?xf32>, tensor<1xi32>) -> tensor<?xf32>
          func.return %0 : tensor<?xf32>
        }"""

    compiled = jitrt.compile(mlir_function, 'test', vectorize=True)

    arg0 = np.random.uniform(0.0, 10.0, size=(8, 10)).astype(np.float32)

    [res] = jitrt.execute(compiled, [arg0])
    # Products of ten values in [0, 10) can be large, so a relative tolerance
    # is used in addition to the absolute one.
    np.testing.assert_allclose(
        res, np.prod(arg0, axis=1), rtol=3e-07, atol=0.01)

  def test_2d_column_mean(self):
    """tf.Mean over dim 1 of a dynamically shaped 2-D f32 tensor."""
    # NOTE(review): despite its name, this test reduces dimension 1 (a
    # per-row mean), like the other `*_row_*` tests. The name is kept so
    # existing test filters keep working.
    mlir_function = """
        func.func @test(%input: tensor<?x?xf32>) -> tensor<?xf32> {
          %dim_to_reduce = "tf.Const"() {value = dense<[1]> : tensor<1xi32>}
             : () -> tensor<1xi32>
          %0 = "tf.Mean"(%input, %dim_to_reduce) {keep_dims = false}
             : (tensor<?x?xf32>, tensor<1xi32>) -> tensor<?xf32>
          func.return %0 : tensor<?xf32>
        }"""

    compiled = jitrt.compile(mlir_function, 'test', vectorize=True)

    arg0 = np.random.uniform(0.0, 10.0, size=(8, 10)).astype(np.float32)

    [res] = jitrt.execute(compiled, [arg0])
    np.testing.assert_allclose(
        res, np.mean(arg0, axis=1), rtol=3e-07, atol=0.01)

  def test_2d_row_any(self):
    """tf.Any over dim 1 of a dynamically shaped 2-D i1 (bool) tensor."""
    mlir_function = """
        func.func @test(%input: tensor<?x?xi1>) -> tensor<?xi1> {
          %dim_to_reduce = "tf.Const"() {value = dense<[1]> : tensor<1xi32>}
             : () -> tensor<1xi32>
          %0 = "tf.Any"(%input, %dim_to_reduce) {keep_dims = false}
             : (tensor<?x?xi1>, tensor<1xi32>) -> tensor<?xi1>
          func.return %0 : tensor<?xi1>
        }"""

    compiled = jitrt.compile(
        mlir_function, 'test', vectorize=True, legalize_i1_tensors=True)

    # `np.bool` was deprecated in NumPy 1.20 and removed in 1.24; `np.bool_`
    # is the NumPy boolean scalar type and works on all versions.
    arg0 = np.random.choice(a=[False, True], size=(8, 10)).astype(np.bool_)

    [res] = jitrt.execute(compiled, [arg0])
    np.testing.assert_equal(res, np.any(arg0, axis=1))

  def test_2d_row_all(self):
    """tf.All over dim 1 of a dynamically shaped 2-D i1 (bool) tensor."""
    mlir_function = """
        func.func @test(%input: tensor<?x?xi1>) -> tensor<?xi1> {
          %dim_to_reduce = "tf.Const"() {value = dense<[1]> : tensor<1xi32>}
             : () -> tensor<1xi32>
          %0 = "tf.All"(%input, %dim_to_reduce) {keep_dims = false}
             : (tensor<?x?xi1>, tensor<1xi32>) -> tensor<?xi1>
          func.return %0 : tensor<?xi1>
        }"""

    compiled = jitrt.compile(
        mlir_function, 'test', vectorize=True, legalize_i1_tensors=True)

    # Narrow rows (width 2) so that both all-True and mixed rows are likely.
    # `np.bool_` replaces the removed `np.bool` alias.
    arg0 = np.random.choice(a=[False, True], size=(40, 2)).astype(np.bool_)

    [res] = jitrt.execute(compiled, [arg0])
    np.testing.assert_equal(res, np.all(arg0, axis=1))

  def test_2d_row_sum_static(self):
    """tf.Sum over dim 1 of a statically shaped tensor<8x8xf32>."""
    mlir_function = """
        func.func @test(%input: tensor<8x8xf32>) -> tensor<8xf32> {
          %dim_to_reduce = "tf.Const"() {value = dense<[1]> : tensor<1xi32>}
             : () -> tensor<1xi32>
          %0 = "tf.Sum"(%input, %dim_to_reduce) {keep_dims = false}
             : (tensor<8x8xf32>, tensor<1xi32>) -> tensor<8xf32>
          func.return %0 : tensor<8xf32>
        }"""

    compiled = jitrt.compile(mlir_function, 'test', vectorize=True)

    arg0 = np.random.uniform(0.0, 10.0, size=(8, 8)).astype(np.float32)

    [res] = jitrt.execute(compiled, [arg0])
    # NOTE(review): atol=1 is much looser than the 0.01 used by the dynamic
    # variant; the reason is not visible here — confirm it is still needed.
    np.testing.assert_allclose(res, np.sum(arg0, axis=1), atol=1)

  def test_2d_column_sum(self):
    """tf.Sum over dim 0 of a dynamically shaped 2-D f32 tensor."""
    mlir_function = """
        func.func @test(%input: tensor<?x?xf32>) -> tensor<?xf32> {
          %dim_to_reduce = "tf.Const"() {value = dense<[0]> : tensor<1xi32>}
             : () -> tensor<1xi32>
          %0 = "tf.Sum"(%input, %dim_to_reduce) {keep_dims = false}
             : (tensor<?x?xf32>, tensor<1xi32>) -> tensor<?xf32>
          func.return %0 : tensor<?xf32>
        }"""

    compiled = jitrt.compile(mlir_function, 'test', vectorize=True)

    arg0 = np.random.uniform(0.0, 10.0, size=(8, 10)).astype(np.float32)

    [res] = jitrt.execute(compiled, [arg0])
    np.testing.assert_allclose(res, np.sum(arg0, axis=0), atol=0.01)

  def test_2d_column_sum_static(self):
    """tf.Sum over dim 0 of a statically shaped tensor<8x8xf32>."""
    mlir_function = """
        func.func @test(%input: tensor<8x8xf32>) -> tensor<8xf32> {
          %dim_to_reduce = "tf.Const"() {value = dense<[0]> : tensor<1xi32>}
             : () -> tensor<1xi32>
          %0 = "tf.Sum"(%input, %dim_to_reduce) {keep_dims = false}
             : (tensor<8x8xf32>, tensor<1xi32>) -> tensor<8xf32>
          func.return %0 : tensor<8xf32>
        }"""

    compiled = jitrt.compile(mlir_function, 'test', vectorize=True)

    arg0 = np.random.uniform(0.0, 10.0, size=(8, 8)).astype(np.float32)

    [res] = jitrt.execute(compiled, [arg0])
    np.testing.assert_allclose(res, np.sum(arg0, axis=0), atol=1)

  def test_2d_row_argmax(self):
    """tf.ArgMax over dim 1 (scalar axis) of a 2-D f32 tensor; i64 result."""
    mlir_function = """
        func.func @test(%input: tensor<?x?xf32>) -> tensor<?xi64> {
          %dim_to_reduce = "tf.Const"() {value = dense<1> : tensor<i32>}
             : () -> tensor<i32>
          %0 = "tf.ArgMax"(%input, %dim_to_reduce)
             : (tensor<?x?xf32>, tensor<i32>) -> tensor<?xi64>
          func.return %0 : tensor<?xi64>
        }"""

    compiled = jitrt.compile(mlir_function, 'test', vectorize=True)

    arg0 = np.random.uniform(0.0, 10.0, size=(8, 10)).astype(np.float32)

    [res] = jitrt.execute(compiled, [arg0])
    np.testing.assert_equal(res, np.argmax(arg0, axis=1))


if __name__ == '__main__':
  test.main()