# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TensorFlow -> jitrt compilation of the `tf.Pack` op."""

import numpy as np

from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
from tensorflow.python.platform import test

jitrt = tf_jitrt.TfJitRtExecutor()


def _random_operand(shape, dtype):
  """Returns a random array of `shape` and `dtype` to use as a Pack operand.

  Boolean operands are drawn uniformly from {False, True}. Casting a sample of
  uniform(0, 10) to bool would be True except on an exact 0.0 draw, which
  would leave the False case effectively untested.
  """
  if np.dtype(dtype) == np.bool_:
    return np.random.randint(0, 2, size=shape).astype(dtype)
  return np.random.uniform(0, 10.0, size=shape).astype(dtype)


class TfPackTest(test.TestCase):

  def pack_and_check(self, src, shape, dtype, axis=0):
    """Compiles `src`, runs it on two random operands and checks the result.

    Args:
      src: MLIR module text defining `@test`, which packs its two arguments.
      shape: shape of each operand (a tuple; `()` for scalars).
      dtype: numpy dtype of both operands.
      axis: the `axis` attribute used by `tf.Pack` in `src`; the reference
        result is computed by stacking along the same axis.
    """
    compiled = jitrt.compile(src, 'test')

    arg0 = _random_operand(shape, dtype)
    arg1 = _random_operand(shape, dtype)

    [res] = jitrt.execute(compiled, [arg0, arg1])
    # Pack only moves data, so the result must match the reference exactly;
    # assert_allclose would still apply its default rtol of 1e-7.
    np.testing.assert_array_equal(res, np.stack([arg0, arg1], axis=axis))

  def test_pack_0d_f32(self):
    self.pack_and_check(
        """
        func.func @test(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<2xf32> {
          %1 = "tf.Pack"(%arg0, %arg1) {axis = 0 : i64}
               : (tensor<f32>, tensor<f32>) -> tensor<2xf32>
          func.return %1 : tensor<2xf32>
        }""", (), np.float32)

  def test_pack_0d_i32(self):
    self.pack_and_check(
        """
        func.func @test(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<2xi32> {
          %1 = "tf.Pack"(%arg0, %arg1) {axis = 0 : i64}
               : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
          func.return %1 : tensor<2xi32>
        }""", (), np.int32)

  def test_pack_0d_i64(self):
    self.pack_and_check(
        """
        func.func @test(%arg0: tensor<i64>, %arg1: tensor<i64>) -> tensor<2xi64> {
          %1 = "tf.Pack"(%arg0, %arg1) {axis = 0 : i64}
               : (tensor<i64>, tensor<i64>) -> tensor<2xi64>
          func.return %1 : tensor<2xi64>
        }""", (), np.int64)

  def test_pack_0d_i1(self):
    self.pack_and_check(
        """
        func.func @test(%arg0: tensor<i1>, %arg1: tensor<i1>) -> tensor<2xi1> {
          %1 = "tf.Pack"(%arg0, %arg1) {axis = 0 : i64}
               : (tensor<i1>, tensor<i1>) -> tensor<2xi1>
          func.return %1 : tensor<2xi1>
        }""", (), bool)

  def test_pack_1d_i32(self):
    # `(4,)` is a one-element tuple; the previous `(4)` was just the int 4.
    self.pack_and_check(
        """
        func.func @test(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>)
            -> tensor<2x4xi32> {
          %1 = "tf.Pack"(%arg0, %arg1) {axis = 0 : i64}
               : (tensor<4xi32>, tensor<4xi32>) -> tensor<2x4xi32>
          func.return %1 : tensor<2x4xi32>
        }""", (4,), np.int32)


if __name__ == '__main__':
  test.main()