• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for Tensorflow -> jitrt compilation."""
16
17import numpy as np
18
19from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
20from tensorflow.python.platform import test
21
# Specialization modes that each test compiles and runs its function under.
specializations = [
    tf_jitrt.Specialization.ENABLED,
    tf_jitrt.Specialization.DISABLED,
    tf_jitrt.Specialization.ALWAYS,
]

# Module-level executor used by the tests below to compile and execute MLIR.
jitrt = tf_jitrt.TfJitRtExecutor()
29
30
class TfSelect(test.TestCase):
  """Tests `tf.Select` compiled through jitrt against NumPy references."""

  def test_select_1d(self):
    """Checks ZerosLike / Less / Select on a dynamically sized 1-D input.

    The compiled function returns three results: zeros shaped like the input,
    the element-wise `input < 0` mask, and the input with the masked elements
    replaced by zero. Each result is compared with its NumPy equivalent for
    every entry in `specializations`.
    """
    # The MLIR source does not depend on the specialization mode, so build it
    # once rather than on every loop iteration.
    mlir_function = """
        func.func @test(%arg0: tensor<?xf32>)
               -> (tensor<?xf32>, tensor<?xi1>, tensor<?xf32>)
        {
          %c = "tf.Const"() {value = dense<0.0> : tensor<f32>}
               : () -> tensor<f32>
          %0 = "tf.ZerosLike"(%arg0)
               : (tensor<?xf32>) -> tensor<?xf32>
          %1 = "tf.Less"(%arg0, %c)
               : (tensor<?xf32>, tensor<f32>) -> tensor<?xi1>
          %2 = "tf.Select"(%1, %0, %arg0)
               : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
          func.return %0, %1, %2 : tensor<?xf32>, tensor<?xi1>, tensor<?xf32>
        }"""

    for specialize in specializations:
      # subTest reports which specialization mode failed and keeps running the
      # remaining modes instead of aborting at the first failure.
      with self.subTest(specialize=specialize):
        compiled = jitrt.compile(mlir_function, 'test', specialize)

        d0 = np.random.randint(1, 10)
        # Sample on both sides of zero so that `tf.Less` can yield true and
        # `tf.Select` exercises both of its branches; a non-negative-only
        # input would always take the pass-through branch.
        arg0 = np.random.uniform(-10.0, 10.0, size=(d0,)).astype(np.float32)

        zeros, less, res = jitrt.execute(compiled, [arg0])
        np.testing.assert_allclose(zeros, np.zeros_like(arg0), atol=0.0)
        # The mask is boolean, so compare it exactly rather than by tolerance.
        np.testing.assert_array_equal(less, np.less(arg0, 0.0))
        # Zeroing the negative entries is the same as clipping at zero.
        np.testing.assert_allclose(res, np.clip(arg0, 0.0, None), atol=0.0)
59
60
# Run the tests via `test.main()` when this file is executed as a script.
if __name__ == '__main__':
  test.main()
63