# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Tensorflow -> jitrt compilation."""

import numpy as np

from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
from tensorflow.python.platform import test

jitrt = tf_jitrt.TfJitRtExecutor()


def softmax(x):
  """NumPy reference softmax computed over the last axis of `x`.

  The per-row maximum is subtracted before exponentiating so that large inputs
  do not overflow in `np.exp`; mathematically this leaves the result unchanged.

  Args:
    x: NumPy array; softmax is normalized along axis -1.

  Returns:
    Array of the same shape as `x` whose entries along the last axis sum to 1.
  """
  z = x - np.max(x, axis=-1, keepdims=True)
  numerator = np.exp(z)
  denominator = np.sum(numerator, axis=-1, keepdims=True)
  return numerator / denominator


class TfSoftmaxTest(test.TestCase):
  """Compiles `tf.Softmax` with jitrt and compares it to a NumPy reference."""

  def _check_softmax(self, mlir_function, shape):
    """Compiles `mlir_function`, runs it on random input, checks the result.

    Args:
      mlir_function: MLIR source text defining a function named `test` that
        takes one f32 tensor and returns its softmax.
      shape: shape of the random float32 input passed to the compiled function.
    """
    compiled = jitrt.compile(mlir_function, 'test', vectorize=True)

    # Use a locally seeded generator so that a failure reproduces with exactly
    # the same input on every run, independent of global NumPy RNG state.
    rng = np.random.default_rng(seed=0)
    arg0 = rng.uniform(1, 5, size=shape).astype(np.float32)

    [res] = jitrt.execute(compiled, [arg0])
    np.testing.assert_allclose(res, softmax(arg0), atol=0.00001)

  def test_dynamic_softmax(self):
    mlir_function = """
      func.func @test(%input: tensor<?x?xf32>) -> tensor<?x?xf32> {
        %0 = "tf.Softmax"(%input) : (tensor<?x?xf32>) -> tensor<?x?xf32>
        func.return %0 : tensor<?x?xf32>
      }"""
    self._check_softmax(mlir_function, (8, 8))

  def test_static_softmax(self):
    mlir_function = """
      func.func @test(%input: tensor<10x8xf32>) -> tensor<10x8xf32> {
        %0 = "tf.Softmax"(%input) : (tensor<10x8xf32>) -> tensor<10x8xf32>
        func.return %0 : tensor<10x8xf32>
      }"""
    self._check_softmax(mlir_function, (10, 8))


if __name__ == '__main__':
  test.main()