• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for TensorFlow -> jitrt compilation."""
16
17import numpy as np
18
19from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
20from tensorflow.python.platform import test
21
# Module-level executor shared by the tests below; they call its `compile`
# and `execute` methods.
jitrt = tf_jitrt.TfJitRtExecutor()
23
24
class TfStridedSliceTest(test.TestCase):
  """Compiles and runs `tf.StridedSlice` programs through jitrt."""

  def test_strided_slice_1d_to_0d(self):
    """Slices element 0 out of a rank-1 tensor and shrinks it to a scalar.

    The program uses begin=[0], end=[1], strides=[1], which selects
    `arg0[0:1]`. `shrink_axis_mask = 1` then drops axis 0, so the function
    returns a rank-0 `tensor<i32>` holding `arg0[0]`.
    """
    mlir_function = """
      func.func @test(%arg0: tensor<3xi32>) -> tensor<i32> {
        %cst_0 = "tf.Const"() {value = dense<1> : tensor<1xi32>}
                 : () -> tensor<1xi32>
        %cst_1 = "tf.Const"() {value = dense<0> : tensor<1xi32>}
                 : () -> tensor<1xi32>
        %0 = "tf.StridedSlice"(%arg0, %cst_1, %cst_0, %cst_0)
             {
               begin_mask       = 0 : i64,
               ellipsis_mask    = 0 : i64,
               end_mask         = 0 : i64,
               new_axis_mask    = 0 : i64,
               shrink_axis_mask = 1 : i64
             } : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>)
              -> tensor<i32>
        func.return %0 : tensor<i32>
      }"""

    compiled = jitrt.compile(mlir_function, 'test')
    arg0 = np.array([1, 2, 3], dtype=np.int32)
    [res] = jitrt.execute(compiled, [arg0])
    # A value-only comparison broadcasts, so a shape-(1,) result would also
    # pass. Dropping the axis is the point of this test, so check rank 0
    # explicitly.
    self.assertEqual(np.shape(res), ())
    # Integer result: compare exactly rather than with a float tolerance.
    np.testing.assert_array_equal(res, arg0[0])
50
51
if __name__ == '__main__':
  # Run the test cases via TensorFlow's test runner when executed directly.
  test.main()
54