# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.Cast` lowered through the Tensorflow -> jitrt pipeline."""

import numpy as np

from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
from tensorflow.python.platform import test

jitrt = tf_jitrt.TfJitRtExecutor()


def _compile(mlir_function):
  """Compiles `mlir_function` with jitrt and returns the `@test` entry point."""
  return jitrt.compile(mlir_function, 'test')


def _random_operand(dtype):
  """Returns 10 samples drawn uniformly from [300, 3000), cast to `dtype`."""
  samples = np.random.uniform(300, 3000, size=10)
  return samples.astype(dtype)


def _execute_single(compiled, operand):
  """Runs `compiled` on a single operand and returns its only result.

  Unpacking fails loudly if the kernel does not produce exactly one result.
  """
  (result,) = jitrt.execute(compiled, [operand])
  return result


class TfCastTest(test.TestCase):
  """Checks values and result dtypes produced by jit-compiled `tf.Cast`."""

  def test_cast_unsigned_signed_i32(self):
    kernel = _compile("""
      func.func @test(%arg0: tensor<?xui32>) -> tensor<?xi32> {
        %0 = "tf.Cast"(%arg0) : (tensor<?xui32>) -> tensor<?xi32>
        func.return %0 : tensor<?xi32>
      }""")
    operand = _random_operand(np.uint32)
    result = _execute_single(kernel, operand)
    # [300, 3000) fits in both ui32 and i32, so values pass through unchanged.
    np.testing.assert_equal(result, operand)
    np.testing.assert_equal(result.dtype, np.int32)

  def test_cast_signed_unsigned_i32(self):
    kernel = _compile("""
      func.func @test(%arg0: tensor<?xi32>) -> tensor<?xui32> {
        %0 = "tf.Cast"(%arg0) : (tensor<?xi32>) -> tensor<?xui32>
        func.return %0 : tensor<?xui32>
      }""")
    operand = _random_operand(np.int32)
    result = _execute_single(kernel, operand)
    # [300, 3000) fits in both i32 and ui32, so values pass through unchanged.
    np.testing.assert_equal(result, operand)
    np.testing.assert_equal(result.dtype, np.uint32)

  def test_cast_unsigned_signed_i32_i64(self):
    kernel = _compile("""
      func.func @test(%arg0: tensor<?xui32>) -> tensor<?xi64> {
        %0 = "tf.Cast"(%arg0) : (tensor<?xui32>) -> tensor<?xi64>
        func.return %0 : tensor<?xi64>
      }""")
    operand = _random_operand(np.uint32)
    result = _execute_single(kernel, operand)
    # Widening from ui32 to i64 preserves every input value.
    np.testing.assert_equal(result, operand)
    np.testing.assert_equal(result.dtype, np.int64)

  def test_cast_signed_unsigned_i64_i8(self):
    kernel = _compile("""
      func.func @test(%arg0: tensor<?xi64>) -> tensor<?xui8> {
        %0 = "tf.Cast"(%arg0) : (tensor<?xi64>) -> tensor<?xui8>
        func.return %0 : tensor<?xui8>
      }""")
    operand = _random_operand(np.int64)
    result = _execute_single(kernel, operand)
    # Inputs exceed the ui8 range, so the reference is numpy's own narrowing
    # conversion of the same operand.
    np.testing.assert_equal(result, operand.astype(np.uint8))
    np.testing.assert_equal(result.dtype, np.uint8)


if __name__ == '__main__':
  test.main()