# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Tensorflow -> jitrt compilation of control-flow ops."""

import numpy as np

from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
from tensorflow.python.platform import test

specializations = [
    tf_jitrt.Specialization.ENABLED,
    tf_jitrt.Specialization.DISABLED,
    tf_jitrt.Specialization.ALWAYS,
]

jitrt = tf_jitrt.TfJitRtExecutor()

# `@test` selects a computation with two boolean predicates:
#   %arg0 == false                 -> %arg2 * %arg3  (else region of tf.IfRegion)
#   %arg0 == true,  %arg1 == true  -> %arg2 + %arg3  (tf.If then_branch = @add)
#   %arg0 == true,  %arg1 == false -> %arg2 - %arg3  (tf.If else_branch = @sub)
# Defined once at module level: the text does not depend on the specialization
# mode, so there is no reason to rebuild it inside the test loop.
_IF_MLIR_FUNCTION = """
  func.func @test(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<?xf32>,
                  %arg3: tensor<?xf32>) -> tensor<?xf32> {
    %0 = "tf.IfRegion"(%arg0) ({
        %1 = "tf.If"(%arg1, %arg2, %arg3)
           {then_branch = @add, else_branch = @sub, is_stateless = true}
           : (tensor<i1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
        "tf.Yield"(%1) : (tensor<?xf32>) -> ()
      }, {
        %2 = "tf.Mul"(%arg2, %arg3) : (tensor<?xf32>, tensor<?xf32>)
           -> tensor<?xf32>
        "tf.Yield"(%2) : (tensor<?xf32>) -> ()
      }) {is_stateless = false} : (tensor<i1>) -> tensor<?xf32>
    func.return %0: tensor<?xf32>
  }

  func.func @add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
    %0 = "tf.Add"(%arg0, %arg1): (tensor<?xf32>, tensor<?xf32>)
       -> tensor<?xf32>
    func.return %0 : tensor<?xf32>
  }

  func.func @sub(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
    %0 = "tf.Sub"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>)
       -> tensor<?xf32>
    func.return %0 : tensor<?xf32>
  }"""

# `@test` squares its input (`@while_body`) for as long as every element is
# strictly less than 100 (`@while_cond`: tf.Less against 100.0, reduced with
# tf.All over both dimensions). Iteration therefore stops as soon as at least
# one element is >= 100.
_WHILE_MLIR_FUNCTION = """
  func.func @test(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
    %0 = "tf.While"(%arg0)
       {body = @while_body, cond = @while_cond, is_stateless = true}
       : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
    func.return %0: tensor<?x?xf32>
  }

  func.func @while_body(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
    %0 = "tf.Square"(%arg0): (tensor<?x?xf32>) -> tensor<?x?xf32>
    func.return %0: tensor<?x?xf32>
  }

  func.func @while_cond(%arg0: tensor<?x?xf32>) -> tensor<i1> {
    %cst = "tf.Const"() {value = dense<100.0> : tensor<f32>}
       : () -> tensor<f32>
    %less = "tf.Less"(%arg0, %cst) {T = f32}
       : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
    %dim_to_reduce = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>}
       : () -> tensor<2xi32>
    %all = "tf.All"(%less, %dim_to_reduce) {keep_dims = false}
       : (tensor<?x?xi1>, tensor<2xi32>) -> tensor<i1>
    func.return %all : tensor<i1>
  }"""


class TfControlflowTest(test.TestCase):
  """Compiles TF control-flow ops with JitRt and checks them against NumPy."""

  def test_if(self):
    """Nested tf.IfRegion / tf.If picks Mul, Sub or Add from two predicates."""
    true = np.array(True)
    false = np.array(False)
    # (outer predicate %arg0, inner predicate %arg1, NumPy reference).
    cases = [
        (false, false, np.multiply),
        (false, true, np.multiply),
        (true, false, np.subtract),
        (true, true, np.add),
    ]

    for specialize in specializations:
      # subTest names the failing specialization mode and lets the remaining
      # modes still run after a failure.
      with self.subTest(specialize=specialize):
        compiled = jitrt.compile(_IF_MLIR_FUNCTION, 'test', specialize)

        d0 = np.random.randint(1, 100)
        arg0 = np.random.uniform(0.0, 10.0, size=d0).astype(np.float32)
        arg1 = np.random.uniform(0.0, 10.0, size=d0).astype(np.float32)

        for cond, pred, reference in cases:
          [res] = jitrt.execute(compiled, [cond, pred, arg0, arg1])
          np.testing.assert_allclose(
              res, reference(arg0, arg1), err_msg=f'cond={cond}, pred={pred}')

  def test_while(self):
    """tf.While squares the input until some element reaches at least 100."""
    for specialize in specializations:
      with self.subTest(specialize=specialize):
        compiled = jitrt.compile(_WHILE_MLIR_FUNCTION, 'test', specialize)

        d0 = np.random.randint(1, 100)
        d1 = np.random.randint(1, 100)
        # Inputs are drawn from [2, 10), so squaring strictly increases every
        # element and the reference loop below is guaranteed to terminate.
        arg0 = np.random.uniform(2.0, 10.0, size=(d0, d1)).astype(np.float32)

        # NumPy reference mirroring @while_cond / @while_body.
        expected = arg0
        while np.all(expected < 100):
          expected = expected * expected

        [res] = jitrt.execute(compiled, [arg0])
        np.testing.assert_allclose(res, expected)


if __name__ == '__main__':
  test.main()