# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.MatMul JIT compilation."""

import numpy as np

from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
from tensorflow.python.platform import test


def matmul():
  """Returns MLIR source for a single dynamically shaped `tf.MatMul`.

  The emitted function is named `@matmul`. It takes two `tensor<?x?xf32>`
  operands with `transpose_a = false` and `transpose_b = false`, and
  returns their product as a `tensor<?x?xf32>`.
  """
  return """
  func.func @matmul(%arg0: tensor<?x?xf32>,
                    %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
    %0 = "tf.MatMul"(%arg0, %arg1) {
           transpose_a = false,
           transpose_b = false
         } : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
    func.return %0 : tensor<?x?xf32>
  }"""


# Shared executor used to compile and run every test case in this module.
jitrt = tf_jitrt.TfJitRtExecutor()


def verify_matmul(compiled, m, k, n, rtol=1e-05):
  """Runs `compiled` on random [m, k] x [k, n] inputs and checks the result.

  Both operands are drawn uniformly from [0, 1) and cast to float32. The
  JIT result is compared against `np.matmul` of the same operands.

  Args:
    compiled: the value returned by `jitrt.compile` for `@matmul`.
    m: number of rows of the left-hand side operand.
    k: shared (contracting) dimension of the two operands.
    n: number of columns of the right-hand side operand.
    rtol: relative tolerance forwarded to `np.testing.assert_allclose`.
  """
  lhs = np.random.uniform(0.0, 1.0, size=(m, k)).astype(np.float32)
  rhs = np.random.uniform(0.0, 1.0, size=(k, n)).astype(np.float32)

  # The compiled function has exactly one result.
  [res] = jitrt.execute(compiled, [lhs, rhs])
  np.testing.assert_allclose(res, np.matmul(lhs, rhs), rtol=rtol)


class TfMatMulTest(test.TestCase):
  """Checks JIT-compiled `tf.MatMul` against NumPy for random small shapes."""

  # Number of random shapes checked per test.
  _NUM_ITERATIONS = 100
  # Exclusive upper bound for every random dimension (dims are in [1, 9]).
  _MAX_DIM = 10

  def _compile(self):
    """Compiles the `@matmul` MLIR module with the shared executor."""
    return jitrt.compile(matmul(), "matmul")

  def _random_dim(self):
    """Returns a random dimension in [1, _MAX_DIM)."""
    return np.random.randint(1, self._MAX_DIM)

  # Matmul: [1, k] x [k, 1]
  def test_dot_product(self):
    compiled = self._compile()
    for _ in range(self._NUM_ITERATIONS):
      k = self._random_dim()
      verify_matmul(compiled, 1, k, 1)

  # Matmul: [1, k] x [k, n]
  def test_vec_mat(self):
    compiled = self._compile()
    for _ in range(self._NUM_ITERATIONS):
      k = self._random_dim()
      n = self._random_dim()
      verify_matmul(compiled, 1, k, n)

  # Matmul: [m, k] x [k, 1]
  def test_mat_vec(self):
    compiled = self._compile()
    for _ in range(self._NUM_ITERATIONS):
      m = self._random_dim()
      k = self._random_dim()
      verify_matmul(compiled, m, k, 1)

  # Matmul: [m, k] x [k, n]
  def test_matmul(self):
    compiled = self._compile()
    for _ in range(self._NUM_ITERATIONS):
      m = self._random_dim()
      k = self._random_dim()
      n = self._random_dim()
      verify_matmul(compiled, m, k, n)


if __name__ == "__main__":
  np.random.seed(0)
  test.main()