# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Tensorflow -> jitrt compilation."""

import numpy as np

from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
from tensorflow.python.platform import test

# Every test compiles under each specialization mode so the transpose codegen
# path is exercised regardless of how arguments are specialized.
specializations = [
    tf_jitrt.Specialization.ENABLED,
    tf_jitrt.Specialization.DISABLED,
    tf_jitrt.Specialization.ALWAYS,
]

# A single executor instance is shared by all tests in this module.
jitrt = tf_jitrt.TfJitRtExecutor()

class TfTransposeTest(test.TestCase):
  """Tests lowering of tf.Transpose through the tf_jitrt codegen path."""

  def _compile_for_transpose(self, mlir_function, specialize):
    """Compiles `mlir_function` with vectorized transpose codegen enabled."""
    return jitrt.compile(
        mlir_function,
        'test',
        specialize,
        vectorize=True,
        codegen_transpose=True)

  def _check_transpose_3d(self, perm):
    """Compiles and runs a 3D tf.Transpose with permutation `perm`.

    Compiles once per specialization mode and checks the result against
    np.transpose on a 32x32x32 float32 input.
    """
    perm_attr = ', '.join(str(axis) for axis in perm)
    mlir_function = f"""
      func.func @test(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {{
        %0 = "tf.Const"() {{ value = dense<[{perm_attr}]> : tensor<3xi64> }}
          : () -> tensor<3xi64>
        %1 = "tf.Transpose"(%arg0, %0)
          : (tensor<?x?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
        func.return %1 : tensor<?x?x?xf32>
      }}"""

    for specialize in specializations:
      compiled = self._compile_for_transpose(mlir_function, specialize)

      dim_size = 32
      arg0 = np.arange(0, dim_size * dim_size * dim_size, 1,
                       np.float32).reshape((dim_size, dim_size, dim_size))

      [res] = jitrt.execute(compiled, [arg0])
      np.testing.assert_array_equal(res, np.transpose(arg0, perm))

  def test_transpose_2d(self):
    for specialize in specializations:
      mlir_function = """
        func.func @test(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
          %0 = "tf.Const"() { value = dense<[1, 0]> : tensor<2xi32> }
               : () -> tensor<2xi32>
          %1 = "tf.Transpose"(%arg0, %0)
               : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
          func.return %1 : tensor<?x?xf32>
        }"""

      compiled = self._compile_for_transpose(mlir_function, specialize)

      # Random dynamic shape so shape specialization is actually exercised.
      d0 = np.random.randint(1, 10)
      d1 = np.random.randint(1, 10)
      arg0 = np.random.uniform(0, 10.0, size=(d0, d1)).astype(np.float32)

      [res] = jitrt.execute(compiled, [arg0])
      np.testing.assert_allclose(res, np.transpose(arg0), atol=0.0)

  def test_transpose_3d_0_2_1(self):
    self._check_transpose_3d((0, 2, 1))

  def test_transpose_3d_2_0_1(self):
    self._check_transpose_3d((2, 0, 1))

  def test_transpose_3d_2_1_0(self):
    self._check_transpose_3d((2, 1, 0))

  def test_transpose_3d_1_2_0(self):
    self._check_transpose_3d((1, 2, 0))

  def test_transpose_3d_1_0_2(self):
    self._check_transpose_3d((1, 0, 2))

  def test_double_transpose_3d(self):
    for specialize in specializations:
      mlir_function = """
        func.func @test(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
          %0 = "tf.Const"() { value = dense<[0, 2, 1]> : tensor<3xi32> }
               : () -> tensor<3xi32>
          %1 = "tf.Const"() { value = dense<[2, 1, 0]> : tensor<3xi32> }
               : () -> tensor<3xi32>
          %2 = "tf.Transpose"(%arg0, %0)
               : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
          %3 = "tf.Transpose"(%2, %1)
               : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
          func.return %3 : tensor<?x?x?xf32>
        }"""

      compiled = self._compile_for_transpose(mlir_function, specialize)

      d0 = np.random.randint(1, 10)
      d1 = np.random.randint(1, 10)
      d2 = np.random.randint(1, 10)
      arg0 = np.random.uniform(0, 10.0, size=(d0, d1, d2)).astype(np.float32)

      [res] = jitrt.execute(compiled, [arg0])
      ref = np.transpose(np.transpose(arg0, (0, 2, 1)), (2, 1, 0))
      np.testing.assert_allclose(res, ref, atol=0.0)

  # Without value specialization, the below tf.Transpose won't compile because
  # the permutation vector must be statically shaped.
  def test_transpose_value_specialization_i32(self):
    mlir_function = """
      func.func @compute(%arg0: tensor<*xf32>,
                    %arg1: tensor<?xi32> {rt.constraint = "value"})
          -> tensor<*xf32> {
        %0 = "tf.Transpose"(%arg0, %arg1)
             : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
        func.return %0 : tensor<*xf32>
      }"""
    compiled = jitrt.compile(mlir_function, 'compute')
    tensor = np.random.uniform(0, 10.0, size=(3, 3)).astype(np.float32)
    perm0 = np.array([1, 0]).astype(np.int32)
    perm1 = np.array([0, 1]).astype(np.int32)

    # Test that the same compiled module with two different value-specialized
    # arguments is handled correctly, i.e. it is specialized twice.
    [res0] = jitrt.execute(compiled, [tensor, perm0])
    [res1] = jitrt.execute(compiled, [tensor, perm1])
    np.testing.assert_allclose(res0, np.transpose(tensor, perm0), atol=0.0)
    np.testing.assert_allclose(res1, np.transpose(tensor, perm1), atol=0.0)

  # Test value specialization of two i64 operands.
  def test_transpose_value_specialization_i64(self):
    mlir_function = """
      func.func @compute(%arg0: tensor<*xf32>,
                    %arg1: tensor<?xi64> {rt.constraint = "value"},
                    %arg2: tensor<?xi64> {rt.constraint = "value"})
          -> tensor<*xf32> {
        %0 = "tf.Transpose"(%arg0, %arg1)
             : (tensor<*xf32>, tensor<?xi64>) -> tensor<*xf32>
        %1 = "tf.Transpose"(%0, %arg2)
             : (tensor<*xf32>, tensor<?xi64>) -> tensor<*xf32>
        func.return %1 : tensor<*xf32>
      }"""
    compiled = jitrt.compile(mlir_function, 'compute')
    tensor = np.random.uniform(0, 10.0, size=(3, 3)).astype(np.float32)
    perm0 = np.array([1, 0]).astype(np.int64)
    perm1 = np.array([0, 1]).astype(np.int64)

    [res] = jitrt.execute(compiled, [tensor, perm0, perm1])
    np.testing.assert_allclose(
        res, np.transpose(np.transpose(tensor, perm0), perm1), atol=0.0)

  # Test that without the value constraint the function cannot compile
  # because the permutation vector is not statically shaped.
  def test_transpose_die_without_value_specialization(self):
    mlir_function = """
      func.func @compute(%arg0: tensor<*xf32>,
                    %arg1: tensor<?xi64>) -> tensor<*xf32> {
        %0 = "tf.Transpose"(%arg0, %arg1)
             : (tensor<*xf32>, tensor<?xi64>) -> tensor<*xf32>
        func.return %0 : tensor<*xf32>
      }"""
    # assertRaises replaces the original hand-rolled try/except/raise; any
    # exception from the compiler is accepted, matching the old behavior.
    with self.assertRaises(Exception):
      jitrt.compile(mlir_function, 'compute')
276
277
if __name__ == '__main__':
  test.main()
280