# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Tensorflow -> jitrt compilation."""

import numpy as np

from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
from tensorflow.python.platform import test

# Module-level executor shared by the tests below; used to compile MLIR
# source and to execute the compiled result.
jitrt = tf_jitrt.TfJitRtExecutor()


class TfBroadcastToTest(test.TestCase):
  """Tests compiling and running `tf.BroadcastTo` through jitrt."""

  def test_broadcast_return(self):
    """Checks a broadcast value returned directly and after a `tf.Add`.

    The MLIR function broadcasts a 1-D f32 input to a runtime-provided 2-D
    shape. It returns both the broadcast tensor itself and its sum with
    itself, so the broadcast result is used as a function result and as an
    operand of another op.
    """
    mlir_function = """
      func.func @test(%arg0: tensor<?xf32>, %arg1: tensor<2xi32>)
          -> (tensor<?x?xf32>, tensor<?x?xf32>) {
        %1 = "tf.BroadcastTo"(%arg0, %arg1)
             : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
        %2 = "tf.Add"(%1, %1)
             : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
        func.return %1, %2 : tensor<?x?xf32>, tensor<?x?xf32>
      }"""

    compiled = jitrt.compile(mlir_function, 'test')

    # Seeded generator so that a failing input shape reproduces on re-run.
    rng = np.random.default_rng(seed=0)
    # A single element, so it can be broadcast to any 2-D target shape.
    arg0 = rng.uniform(0, 10.0, size=1).astype(np.float32)
    # Target shape with two dims in [0, 10). This is the same range the
    # original float-truncation produced; a 0 dim yields an empty result.
    arg1 = rng.integers(0, 10, size=2, dtype=np.int32)

    [res1, res2] = jitrt.execute(compiled, [arg0, arg1])

    expected = np.broadcast_to(arg0, tuple(arg1))
    np.testing.assert_allclose(res1, expected, atol=0.0)
    # x + x is exact in float32, so no absolute tolerance is needed.
    np.testing.assert_allclose(res2, expected * 2, atol=0.0)


if __name__ == '__main__':
  test.main()