• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for Tensorflow -> jitrt compilation."""
16
17import numpy as np
18
19from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
20from tensorflow.python.platform import test
21
# Shared executor instance used by every test below to JIT-compile and run
# MLIR modules through the tf_jitrt Python binding.
jitrt = tf_jitrt.TfJitRtExecutor()

# Shape-specialization modes exercised by the tests; each test that takes a
# `specialize` argument is run once per mode.
specializations = [
    tf_jitrt.Specialization.ENABLED,
    tf_jitrt.Specialization.DISABLED,
    tf_jitrt.Specialization.ALWAYS,
]
29
30
class TfBinaryBcastTest(test.TestCase):
  """Checks broadcasting behavior of binary TF ops compiled via tf_jitrt."""

  def test_bcast_2d_1d(self):
    """Chains elementwise ops that broadcast 1-d operands against a 2-d one."""
    mlir_module = """
      func.func @test(%arg0: tensor<?x4xf32>,
                      %arg1: tensor<4xf32>,
                      %arg2: tensor<4xf32>) -> tensor<?x4xf32> {
        %0 = "tf.Log1p"(%arg0)
             : (tensor<?x4xf32>) -> tensor<?x4xf32>
        %1 = "tf.Sub"(%0, %arg1)
             : (tensor<?x4xf32>, tensor<4xf32>) -> tensor<?x4xf32>
        %2 = "tf.Mul"(%1, %arg2)
             : (tensor<?x4xf32>, tensor<4xf32>) -> tensor<?x4xf32>
        %3 = "tf.Atan2"(%2, %arg2)
             : (tensor<?x4xf32>, tensor<4xf32>) -> tensor<?x4xf32>
        func.return %3 : tensor<?x4xf32>
      }"""

    rows = np.random.randint(1, 10)

    operand = np.random.uniform(0, 10.0, size=(rows, 4)).astype(np.float32)
    bias = np.random.uniform(0, 10.0, size=(4)).astype(np.float32)
    scale = np.random.uniform(0, 10.0, size=(4)).astype(np.float32)

    # Reference result computed once with numpy broadcasting.
    expected = np.arctan2((np.log1p(operand) - bias) * scale, scale)

    for spec in specializations:
      for vec in (True, False):
        executable = jitrt.compile(mlir_module, 'test', spec, vec)
        [out] = jitrt.execute(executable, [operand, bias, scale])
        np.testing.assert_allclose(out, expected, atol=1e-04)

  def test_bcast_2d_2d(self):
    """Runs tf.Mul over every pairwise combination of 2-d broadcast shapes."""
    mlir_module = """
      func.func @test(%arg0: tensor<?x?xf32>,
                      %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
        %0 = "tf.Mul"(%arg0, %arg1)
             : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
        func.return %0 : tensor<?x?xf32>
      }"""

    rows = np.random.randint(1, 10)
    cols = np.random.randint(1, 10)

    # All broadcast-compatible 2-d shapes for a (rows, cols) result.
    shapes = [(1, 1), (1, cols), (rows, 1), (rows, cols)]
    lhs_args = [
        np.random.uniform(0, 10.0, size=shape).astype(np.float32)
        for shape in shapes
    ]
    rhs_args = [
        np.random.uniform(0, 10.0, size=shape).astype(np.float32)
        for shape in shapes
    ]

    for spec in specializations:
      executable = jitrt.compile(mlir_module, 'test', spec)

      for lhs in lhs_args:
        for rhs in rhs_args:
          [out] = jitrt.execute(executable, [lhs, rhs])
          np.testing.assert_allclose(out, lhs * rhs, atol=1e-07)

  def test_bcast_2d_1d_0d(self):
    """Broadcasts a scalar and a vector through a chain of additions."""
    mlir_module = """
      func.func @compute(%arg0: tensor<?x4xf32>,
                         %arg1: tensor<4xf32>,
                         %arg2: tensor<f32>) -> tensor<?x4xf32> {
        %0 = "tf.AddV2"(%arg1, %arg2)
             : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
        %1 = "tf.AddV2"(%arg0, %0)
             : (tensor<?x4xf32>, tensor<4xf32>) -> tensor<?x4xf32>
        %2 = "tf.AddV2"(%1, %0)
             : (tensor<?x4xf32>, tensor<4xf32>) -> tensor<?x4xf32>
        func.return %2 : tensor<?x4xf32>
      }"""

    for spec in specializations:
      executable = jitrt.compile(mlir_module, 'compute', spec)

      matrix = np.random.uniform(0, 10.0, size=(1, 4)).astype(np.float32)
      vector = np.random.uniform(0, 10.0, size=(4)).astype(np.float32)
      scalar = np.random.uniform(0, 10.0, size=()).astype(np.float32)

      [out] = jitrt.execute(executable, [matrix, vector, scalar])

      # Mirror the MLIR computation step by step with numpy.
      shifted = np.add(vector, scalar)
      expected = np.add(np.add(matrix, shifted), shifted)

      np.testing.assert_allclose(out, expected, atol=0.0)

  def test_bcast_3d_3d(self):
    """Adds two 3-d tensors whose innermost dimension is static."""
    mlir_module = """
      func.func @test(%arg0: tensor<?x?x12xf32>,
                      %arg1: tensor<?x?x12xf32>) -> tensor<?x?x12xf32> {
        %0 = "tf.AddV2"(%arg0, %arg1)
             : (tensor<?x?x12xf32>, tensor<?x?x12xf32>) -> tensor<?x?x12xf32>
        func.return %0 : tensor<?x?x12xf32>
      }"""

    dim0 = np.random.randint(1, 10)
    dim1 = np.random.randint(1, 10)

    lhs = np.random.uniform(0, 10.0, size=(dim0, dim1, 12)).astype(np.float32)
    rhs = np.random.uniform(0, 10.0, size=(dim0, dim1, 12)).astype(np.float32)

    for spec in specializations:
      for vec in (True, False):
        executable = jitrt.compile(mlir_module, 'test', spec, vec)
        [out] = jitrt.execute(executable, [lhs, rhs])
        np.testing.assert_allclose(out, lhs + rhs, atol=0.0)

  def test_bcast_unranked_0d(self):
    """Broadcasts a rank-constrained unranked operand against a scalar."""
    mlir_module = """
      func.func @compute(%arg0: tensor<*xf32> {rt.constraint = "rank"},
                         %arg1: tensor<f32>) -> tensor<*xf32> {
        %0 = "tf.AddV2"(%arg0, %arg1)
             : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
        func.return %0 : tensor<*xf32>
      }"""

    executable = jitrt.compile(mlir_module, 'compute')

    matrix = np.random.uniform(0, 10.0, size=(4, 4)).astype(np.float32)
    scalar = np.random.uniform(0, 10.0, size=()).astype(np.float32)

    [out] = jitrt.execute(executable, [matrix, scalar])

    np.testing.assert_allclose(out, np.add(matrix, scalar), atol=0.0)

  def test_bcast_unranked_unranked(self):
    """Broadcasts two rank-constrained unranked operands together."""
    mlir_module = """
      func.func @compute(%arg0: tensor<*xf32> {rt.constraint = "rank"},
                         %arg1: tensor<*xf32> {rt.constraint = "rank"})
          -> tensor<*xf32> {
        %0 = "tf.AddV2"(%arg0, %arg1)
             : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
        func.return %0 : tensor<*xf32>
      }"""

    executable = jitrt.compile(mlir_module, 'compute')

    row = np.random.uniform(0, 10.0, size=(1, 4)).astype(np.float32)
    col = np.random.uniform(0, 10.0, size=(4, 1)).astype(np.float32)

    [out] = jitrt.execute(executable, [row, col])

    np.testing.assert_allclose(out, np.add(row, col), atol=0.0)

  def test_bcast_1d_1d_error(self):
    """Non-broadcastable shapes must fail at run time, not at compile time."""
    mlir_module = """
      func.func @compute(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>)
          -> tensor<?xf32> {
        %0 = "tf.AddV2"(%arg0, %arg1)
             : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
        func.return %0 : tensor<?xf32>
      }"""

    lhs = np.random.uniform(0, 10.0, size=(2)).astype(np.float32)
    rhs = np.random.uniform(0, 10.0, size=(3)).astype(np.float32)

    for spec in specializations:
      executable = jitrt.compile(mlir_module, 'compute', spec)

      with self.assertRaisesRegex(Exception, 'required broadcastable shapes'):
        jitrt.execute(executable, [lhs, rhs])

  def test_bcast_value_rank0(self):
    """One compiled module must handle distinct value-specialized scalars."""
    mlir_module = """
      func.func @compute(%arg0: tensor<*xi32>,
                         %arg1: tensor<i32> {rt.constraint = "value"})
          -> tensor<*xi32> {
        %0 = "tf.AddV2"(%arg0, %arg1)
             : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
        func.return %0 : tensor<*xi32>
      }"""
    executable = jitrt.compile(mlir_module, 'compute')
    # Re-executing the same compiled module with two different specialized
    # values checks that value specialization does not pin the first scalar.
    operand = np.random.uniform(0, 10.0, size=(3)).astype(np.int32)
    scalars = [
        np.random.uniform(0, 10.0, size=()).astype(np.int32) for _ in range(2)
    ]
    for scalar in scalars:
      [out] = jitrt.execute(executable, [operand, scalar])
      np.testing.assert_allclose(out, np.add(operand, scalar), atol=0.0)

  def test_bcast_value_die_if_unsinkable(self):
    """Value-specializing an f32 operand must be rejected at compile time."""
    mlir_module = """
      func.func @compute(%arg0: tensor<*xf32>,
                    %arg1: tensor<f32> {rt.constraint = "value"})
          -> tensor<*xf32> {
        %0 = "tf.AddV2"(%arg0, %arg1)
             : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
        func.return %0 : tensor<*xf32>
      }"""

    with self.assertRaisesRegex(Exception,
                                'cannot sink operand type: tensor<f32>'):
      jitrt.compile(mlir_module, 'compute')
236
237
if __name__ == '__main__':
  # Discover and run all TestCase classes in this module.
  test.main()
240