• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tf.MatMul JIT compilation."""
16
17import numpy as np
18
19from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
20from tensorflow.python.platform import test
21
22
def matmul():
  """Returns MLIR source for a `@matmul` function built around tf.MatMul.

  Both operands and the result are typed `tensor<?x?xf32>` (2-D f32 with
  dynamic dimensions), and both `transpose_a` and `transpose_b` are false.
  """
  mlir_text = """
  func.func @matmul(%arg0: tensor<?x?xf32>,
               %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
    %0 = "tf.MatMul"(%arg0, %arg1) {
           transpose_a = false,
           transpose_b = false
         } : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
    func.return %0 : tensor<?x?xf32>
  }"""
  return mlir_text
33
34
# Single executor shared by the whole file: the tests use it to compile the
# MLIR (`jitrt.compile`) and `verify_matmul` uses it to run the result
# (`jitrt.execute`).
jitrt = tf_jitrt.TfJitRtExecutor()


def verify_matmul(compiled, m, k, n):
  """Runs `compiled` on random float32 operands and compares with np.matmul.

  The left operand has shape (m, k) and the right operand has shape (k, n).
  Both are drawn from np.random.uniform(0.0, 1.0), left first.

  Args:
    compiled: Executable handle, as produced by `jitrt.compile(...)` in the
      tests below.
    m: Number of rows of the left operand.
    k: Inner (contraction) dimension shared by both operands.
    n: Number of columns of the right operand.
  """
  operands = [
      np.random.uniform(0.0, 1.0, size=shape).astype(np.float32)
      for shape in ((m, k), (k, n))
  ]
  # The executor returns a sequence of results; exactly one is expected.
  (actual,) = jitrt.execute(compiled, operands)
  expected = np.matmul(*operands)
  np.testing.assert_allclose(actual, expected, rtol=1e-05)
44
45
class TfMatMulTest(test.TestCase):
  """Compares the JIT-compiled tf.MatMul against np.matmul on random shapes."""

  # Number of random shapes checked by each test.
  _NUM_SHAPES = 100

  def _check_random_shapes(self, m=None, k=None, n=None):
    """Verifies `matmul` on random [m, k] x [k, n] operands.

    The function is compiled once and then executed for `_NUM_SHAPES` shapes.
    A dimension passed explicitly stays fixed. A dimension left as None is
    redrawn from np.random.randint(1, 10), i.e. 1..9, on every iteration;
    draws happen in m, k, n order.

    Args:
      m: Fixed number of rows of the left operand, or None to randomize.
      k: Fixed contraction dimension, or None to randomize.
      n: Fixed number of columns of the right operand, or None to randomize.
    """
    compiled = jitrt.compile(matmul(), "matmul")
    for _ in range(self._NUM_SHAPES):
      dims = [np.random.randint(1, 10) if d is None else d for d in (m, k, n)]
      # subTest puts the offending shape into the failure report.
      with self.subTest(m=dims[0], k=dims[1], n=dims[2]):
        verify_matmul(compiled, *dims)

  # Matmul: [1, k] x [k, 1]
  def test_dot_product(self):
    self._check_random_shapes(m=1, n=1)

  # Matmul: [1, k] x [k, n]
  def test_vec_mat(self):
    self._check_random_shapes(m=1)

  # Matmul: [m, k] x [k, 1]
  def test_mat_vec(self):
    self._check_random_shapes(n=1)

  # Matmul: [m, k] x [k, n]
  def test_matmul(self):
    self._check_random_shapes()
79
80
if __name__ == "__main__":
  # Seed NumPy's global RNG (used for the random shapes and operand values)
  # before handing control to the TF test runner; only applies when this file
  # is executed as a script.
  np.random.seed(0)
  test.main()
84