• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for Tensorflow -> jitrt compilation."""
16
17import numpy as np
18
19from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
20from tensorflow.python.platform import test
21
# Module-level executor instance; TfPackTest.pack_and_check uses it to compile
# each MLIR snippet and to execute the compiled result.
22jitrt = tf_jitrt.TfJitRtExecutor()
23
24
class TfPackTest(test.TestCase):
  """Checks `tf.Pack` compiled through jitrt against NumPy stacking."""

  @staticmethod
  def _random_operand(shape, dtype):
    """Builds one random input array for a pack test.

    Args:
      shape: Shape of the array to generate; `()` produces a 0-d array.
      dtype: NumPy dtype (or Python type such as `bool`) of the result.

    Returns:
      A NumPy array of the requested shape and dtype.
    """
    samples = np.random.uniform(0, 10.0, size=shape)
    if np.dtype(dtype) == np.bool_:
      # Casting the raw samples to bool maps every nonzero value to True, so
      # the inputs would practically never contain False. Threshold at the
      # midpoint of the sampling range instead so both values are exercised.
      # np.asarray keeps the 0-d case an ndarray rather than a NumPy scalar.
      return np.asarray(samples >= 5.0, dtype=dtype)
    return samples.astype(dtype)

  def pack_and_check(self, src, shape, dtype):
    """Compiles `src`, runs it on two random inputs and checks the output.

    Args:
      src: MLIR source text containing the `test` function to compile.
      shape: Shape of each of the two inputs.
      dtype: Element type of each of the two inputs.
    """
    compiled = jitrt.compile(src, 'test')

    arg0 = self._random_operand(shape, dtype)
    arg1 = self._random_operand(shape, dtype)

    [res] = jitrt.execute(compiled, [arg0, arg1])
    # The expected value is the two inputs stacked along a new leading axis.
    np.testing.assert_allclose(res, np.array([arg0, arg1]), atol=0.0)

  def test_pack_0d_f32(self):
    """Packs two f32 scalars into a tensor<2xf32>."""
    self.pack_and_check(
        """
      func.func @test(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<2xf32> {
        %1 = "tf.Pack"(%arg0, %arg1) {axis = 0 : i64}
             : (tensor<f32>, tensor<f32>) -> tensor<2xf32>
        func.return %1 : tensor<2xf32>
      }""", (), np.float32)

  def test_pack_0d_i32(self):
    """Packs two i32 scalars into a tensor<2xi32>."""
    self.pack_and_check(
        """
      func.func @test(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<2xi32> {
        %1 = "tf.Pack"(%arg0, %arg1) {axis = 0 : i64}
             : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
        func.return %1 : tensor<2xi32>
      }""", (), np.int32)

  def test_pack_0d_i64(self):
    """Packs two i64 scalars into a tensor<2xi64>."""
    self.pack_and_check(
        """
      func.func @test(%arg0: tensor<i64>, %arg1: tensor<i64>) -> tensor<2xi64> {
        %1 = "tf.Pack"(%arg0, %arg1) {axis = 0 : i64}
             : (tensor<i64>, tensor<i64>) -> tensor<2xi64>
        func.return %1 : tensor<2xi64>
      }""", (), np.int64)

  def test_pack_0d_i1(self):
    """Packs two i1 (boolean) scalars into a tensor<2xi1>."""
    self.pack_and_check(
        """
      func.func @test(%arg0: tensor<i1>, %arg1: tensor<i1>) -> tensor<2xi1> {
        %1 = "tf.Pack"(%arg0, %arg1) {axis = 0 : i64}
             : (tensor<i1>, tensor<i1>) -> tensor<2xi1>
        func.return %1 : tensor<2xi1>
      }""", (), bool)

  def test_pack_1d_i32(self):
    """Packs two tensor<4xi32> vectors into a tensor<2x4xi32>."""
    # Note the trailing comma: `(4)` is just the int 4, `(4,)` is a shape tuple.
    self.pack_and_check(
        """
      func.func @test(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>)
         -> tensor<2x4xi32> {
        %1 = "tf.Pack"(%arg0, %arg1) {axis = 0 : i64}
             : (tensor<4xi32>, tensor<4xi32>) -> tensor<2x4xi32>
        func.return %1 : tensor<2x4xi32>
      }""", (4,), np.int32)
81
82
# When this file is executed as a script, call the `main` entry point of
# tensorflow.python.platform.test (presumably runs the test cases above).
83if __name__ == '__main__':
84  test.main()
85