# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Tensorflow -> jitrt compilation."""

import numpy as np

from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
from tensorflow.python.platform import test

# Every test compiles its MLIR program once per specialization mode below.
specializations = [
    tf_jitrt.Specialization.ENABLED,
    tf_jitrt.Specialization.DISABLED,
    tf_jitrt.Specialization.ALWAYS,
]

# A single executor shared by all tests in this module.
jitrt = tf_jitrt.TfJitRtExecutor()


class TfReshapeTest(test.TestCase):
  """Checks `tf.Reshape` compiled through jitrt against `np.reshape`.

  Each test compiles an MLIR function named `test` once per entry in
  `specializations`. The per-mode assertions run inside `subTest`, so a
  failure reports which specialization mode produced it.
  """

  def test_reshape_unknown_1d(self):
    mlir_function = """
      func.func @test(%arg0: tensor<?xf32>, %arg1: tensor<2xi32>)
          -> tensor<?x?xf32> {
        %0 = "tf.Reshape"(%arg0, %arg1)
             : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
        func.return %0 : tensor<?x?xf32>
      }"""

    for specialize in specializations:
      with self.subTest(specialize=specialize):
        compiled = jitrt.compile(mlir_function, 'test', specialize)

        # Even d0 in [2, 18], so that [2, d0 // 2] covers every element.
        d0 = np.random.randint(1, 10) * 2
        arg0 = np.random.uniform(0, 10.0, size=(d0,)).astype(np.float32)

        shape = np.array([2, d0 // 2], dtype=np.int32)
        [res] = jitrt.execute(compiled, [arg0, shape])
        np.testing.assert_allclose(res, np.reshape(arg0, shape), atol=0.0)

        # A single -1 entry is inferred from the element count.
        shape = np.array([2, -1], dtype=np.int32)
        [res] = jitrt.execute(compiled, [arg0, shape])
        np.testing.assert_allclose(res, np.reshape(arg0, shape), atol=0.0)

        # d0 <= 18 is never a multiple of 30, so no inferred size fits.
        shape = np.array([30, -1], dtype=np.int32)
        with self.assertRaises(RuntimeError):
          jitrt.execute(compiled, [arg0, shape])

  def test_reshape_unknown_2d(self):
    mlir_function = """
      func.func @test(%arg0: tensor<?x?xf32>, %arg1: tensor<1xi32>)
          -> tensor<?xf32> {
        %0 = "tf.Reshape"(%arg0, %arg1)
             : (tensor<?x?xf32>, tensor<1xi32>) -> tensor<?xf32>
        func.return %0 : tensor<?xf32>
      }"""

    for specialize in specializations:
      with self.subTest(specialize=specialize):
        compiled = jitrt.compile(mlir_function, 'test', specialize)

        d0 = np.random.randint(1, 10) * 2
        d1 = np.random.randint(1, 10) * 2
        arg0 = np.random.uniform(0, 10.0, size=(d0, d1)).astype(np.float32)

        shape = np.array([d0 * d1], dtype=np.int32)
        [res] = jitrt.execute(compiled, [arg0, shape])
        np.testing.assert_allclose(res, np.reshape(arg0, shape), atol=0.0)

        shape = np.array([-1], dtype=np.int32)
        [res] = jitrt.execute(compiled, [arg0, shape])
        np.testing.assert_allclose(res, np.reshape(arg0, shape), atol=0.0)

        # d1 >= 2, so a length-d0 result cannot hold all d0 * d1 elements.
        shape = np.array([d0], dtype=np.int32)
        with self.assertRaises(RuntimeError):
          jitrt.execute(compiled, [arg0, shape])

  def test_reshape_zero_dim(self):
    mlir_function = """
      func.func @test(%arg0: tensor<?xf32>, %arg1: tensor<1xi32>)
          -> tensor<?xf32> {
        %0 = "tf.Reshape"(%arg0, %arg1)
             : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
        func.return %0 : tensor<?xf32>
      }"""

    for specialize in specializations:
      with self.subTest(specialize=specialize):
        compiled = jitrt.compile(mlir_function, 'test', specialize)

        empty = np.array([], dtype=np.float32)

        zero = np.array([0], dtype=np.int32)
        [res] = jitrt.execute(compiled, [empty, zero])
        np.testing.assert_equal(res.shape, [0])

        neg1 = np.array([-1], dtype=np.int32)
        [res] = jitrt.execute(compiled, [empty, neg1])
        np.testing.assert_equal(res.shape, [0])

        # Negative sizes other than -1 are expected to be rejected.
        neg2 = np.array([-2], dtype=np.int32)
        with self.assertRaises(RuntimeError):
          jitrt.execute(compiled, [empty, neg2])

        # An empty input cannot fill a one-element result.
        one = np.array([1], dtype=np.int32)
        with self.assertRaises(RuntimeError):
          jitrt.execute(compiled, [empty, one])

  def test_reshape_zero_dim_3d(self):
    # The input is first reshaped to the constant [3, 0, 5], then to %arg1.
    mlir_function = """
      func.func @test(%arg0: tensor<?xf32>, %arg1: tensor<3xi32>)
          -> tensor<?x?x?xf32> {
        %0 = "tf.Const"() { value = dense<[3, 0, 5]> : tensor<3xi32> }
             : () -> tensor<3xi32>
        %1 = "tf.Reshape"(%arg0, %0)
             : (tensor<?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
        %2 = "tf.Reshape"(%1, %arg1)
             : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
        func.return %2 : tensor<?x?x?xf32>
      }"""

    for specialize in specializations:
      with self.subTest(specialize=specialize):
        compiled = jitrt.compile(mlir_function, 'test', specialize)

        empty = np.array([], dtype=np.float32)

        shape = np.array([3, 0, -1], dtype=np.int32)
        [res] = jitrt.execute(compiled, [empty, shape])
        # TODO(kramerb): This should be [3, 0, 5]
        np.testing.assert_equal(res.shape, [3, 0, 0])

        # More than one -1 entry is expected to be rejected.
        shape = np.array([3, -1, -1], dtype=np.int32)
        with self.assertRaises(RuntimeError):
          jitrt.execute(compiled, [empty, shape])


if __name__ == '__main__':
  test.main()