# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TensorFlow -> JitRt compilation of `tf.Select`."""

import numpy as np

from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
from tensorflow.python.platform import test

specializations = [
    tf_jitrt.Specialization.ENABLED,
    tf_jitrt.Specialization.DISABLED,
    tf_jitrt.Specialization.ALWAYS,
]

jitrt = tf_jitrt.TfJitRtExecutor()

# Computes Select(x < 0, ZerosLike(x), x), i.e. max(x, 0) elementwise, and also
# returns the intermediate zeros tensor and the predicate so each op can be
# checked on its own. The function body does not depend on the specialization
# mode, so it is defined once here rather than inside the test loop.
_SELECT_1D_MLIR = """
  func.func @test(%arg0: tensor<?xf32>)
      -> (tensor<?xf32>, tensor<?xi1>, tensor<?xf32>)
  {
    %c = "tf.Const"() {value = dense<0.0> : tensor<f32>}
         : () -> tensor<f32>
    %0 = "tf.ZerosLike"(%arg0)
         : (tensor<?xf32>) -> tensor<?xf32>
    %1 = "tf.Less"(%arg0, %c)
         : (tensor<?xf32>, tensor<f32>) -> tensor<?xi1>
    %2 = "tf.Select"(%1, %0, %arg0)
         : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
    func.return %0, %1, %2 : tensor<?xf32>, tensor<?xi1>, tensor<?xf32>
  }"""


class TfSelect(test.TestCase):
  """Checks JitRt compilation of `tf.Select` for every specialization mode."""

  def _check_select_1d(self, compiled, arg0):
    """Executes `compiled` on the 1-D float32 `arg0` and checks all results.

    Args:
      compiled: Handle returned by `jitrt.compile` for `_SELECT_1D_MLIR`.
      arg0: 1-D `np.float32` array passed as the function argument.
    """
    zeros, less, res = jitrt.execute(compiled, [arg0])
    np.testing.assert_allclose(zeros, np.zeros_like(arg0), atol=0.0)
    np.testing.assert_allclose(less, np.less(arg0, 0.0), atol=0.0)
    # Select(x < 0, 0, x) is max(x, 0), which np.clip expresses directly.
    np.testing.assert_allclose(res, np.clip(arg0, 0.0, None), atol=0.0)

  def test_select_1d(self):
    for specialize in specializations:
      with self.subTest(specialize=specialize):
        compiled = jitrt.compile(_SELECT_1D_MLIR, 'test', specialize)

        # Fixed input with negative, zero and positive elements: guarantees
        # that both Select branches (and the x == 0 boundary) are exercised
        # regardless of the random draw below. The previous input range
        # (0, 10) never produced a true predicate.
        self._check_select_1d(
            compiled, np.array([-2.5, 0.0, 3.0, -0.5], dtype=np.float32))

        # Random dynamic size with values of both signs.
        d0 = np.random.randint(1, 10)
        arg0 = np.random.uniform(-10.0, 10.0, size=(d0,)).astype(np.float32)
        self._check_select_1d(compiled, arg0)


if __name__ == '__main__':
  test.main()