• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TensorFlow control flow -> jitrt compilation."""
16
17import numpy as np
18
19from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
20from tensorflow.python.platform import test
21
# jitrt specialization modes; every test compiles and runs its MLIR module
# once for each of these, in this order.
specializations = [
    getattr(tf_jitrt.Specialization, mode_name)
    for mode_name in ('ENABLED', 'DISABLED', 'ALWAYS')
]

# Single executor shared by all tests to compile and execute MLIR modules.
jitrt = tf_jitrt.TfJitRtExecutor()
29
30
class TfControlflowTest(test.TestCase):
  """Checks that TF control-flow ops compile and run correctly via jitrt.

  Each test compiles one MLIR module once per entry in `specializations`,
  runs it on randomly sized inputs, and compares the result against a NumPy
  reference.
  """

  def test_if(self):
    """Nested tf.IfRegion / tf.If selects between Mul, Sub and Add."""
    # The module text does not depend on the specialization mode, so it is
    # built once rather than on every loop iteration.
    mlir_function = """
        func.func @test(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<?xf32>,
                   %arg3: tensor<?xf32>) -> tensor<?xf32> {
          %0 = "tf.IfRegion"(%arg0) ({
              %1 = "tf.If"(%arg1, %arg2, %arg3)
                 {then_branch = @add, else_branch = @sub, is_stateless = true}
                 : (tensor<i1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
              "tf.Yield"(%1) : (tensor<?xf32>) -> ()
            }, {
              %2 = "tf.Mul"(%arg2, %arg3) : (tensor<?xf32>, tensor<?xf32>)
                 -> tensor<?xf32>
              "tf.Yield"(%2) : (tensor<?xf32>) -> ()
            }) {is_stateless = false} : (tensor<i1>) -> tensor<?xf32>
          func.return %0: tensor<?xf32>
        }

        func.func @add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
          %0 = "tf.Add"(%arg0, %arg1): (tensor<?xf32>, tensor<?xf32>)
             -> tensor<?xf32>
          func.return %0 : tensor<?xf32>
        }

        func.func @sub(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
          %0 = "tf.Sub"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>)
             -> tensor<?xf32>
          func.return %0 : tensor<?xf32>
        }"""

    for specialize in specializations:
      # subTest reports which specialization failed and still checks the
      # remaining ones instead of aborting at the first failure.
      with self.subTest(specialize=specialize):
        compiled = jitrt.compile(mlir_function, 'test', specialize)

        d0 = np.random.randint(1, 100)

        arg0 = np.random.uniform(0.0, 10.0, size=d0).astype(np.float32)
        arg1 = np.random.uniform(0.0, 10.0, size=d0).astype(np.float32)

        true = np.array(True)
        false = np.array(False)
        # (outer predicate %arg0, inner predicate %arg1, expected result).
        # When %arg0 is false the IfRegion's else-region multiplies. When it
        # is true, tf.If dispatches to @add if %arg1 is true, else to @sub.
        cases = [
            (false, false, arg0 * arg1),
            (false, true, arg0 * arg1),
            (true, false, arg0 - arg1),
            (true, true, arg0 + arg1),
        ]
        for outer_pred, inner_pred, expected in cases:
          [res] = jitrt.execute(compiled, [outer_pred, inner_pred, arg0, arg1])
          np.testing.assert_allclose(res, expected)

  def test_while(self):
    """tf.While squares its input until some element reaches 100."""
    # Keep squaring while every element is < 100. The loop stops as soon as
    # at least one element is >= 100, mirroring tf.All(tf.Less(x, 100)) in
    # @while_cond.
    mlir_function = """
        func.func @test(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
          %0 = "tf.While"(%arg0)
             {body = @while_body, cond = @while_cond, is_stateless = true}
             : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
          func.return %0: tensor<?x?xf32>
        }

        func.func @while_body(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
          %0 = "tf.Square"(%arg0): (tensor<?x?xf32>) -> tensor<?x?xf32>
          func.return %0: tensor<?x?xf32>
        }

        func.func @while_cond(%arg0: tensor<?x?xf32>) -> tensor<i1> {
          %cst = "tf.Const"() {value = dense<100.0> : tensor<f32>}
             : () -> tensor<f32>
          %less = "tf.Less"(%arg0, %cst) {T = f32}
             : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
          %dim_to_reduce = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>}
             : () -> tensor<2xi32>
          %all = "tf.All"(%less, %dim_to_reduce) {keep_dims = false}
             : (tensor<?x?xi1>, tensor<2xi32>) -> tensor<i1>
          func.return %all : tensor<i1>
        }"""

    for specialize in specializations:
      with self.subTest(specialize=specialize):
        compiled = jitrt.compile(mlir_function, 'test', specialize)

        # Both dims are >= 1, so the all-reduction never sees an empty tensor.
        # np.all of an empty array is True, which would never terminate.
        d0 = np.random.randint(1, 100)
        d1 = np.random.randint(1, 100)

        # Starting values >= 2.0 strictly grow when squared, so the loop
        # is guaranteed to terminate.
        arg0 = np.random.uniform(2.0, 10.0, size=(d0, d1)).astype(np.float32)

        # NumPy reference implementation of the same loop.
        np_res = arg0
        while np.all(np.less(np_res, 100)):
          np_res = np_res * np_res

        [res] = jitrt.execute(compiled, [arg0])
        np.testing.assert_allclose(res, np_res)
120
121
if __name__ == '__main__':
  # Run this module's tests through TensorFlow's test runner
  # (tensorflow.python.platform.test).
  test.main()
124