# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for numerical correctness of tf.math operations."""

import enum
import numpy as np

from absl import flags
from absl.testing import parameterized
from tensorflow import math

from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
from tensorflow.python.platform import test

jitrt = tf_jitrt.TfJitRtExecutor()

FLAGS = flags.FLAGS
flags.DEFINE_integer('iters', 1000, 'Number of test iterations')
flags.DEFINE_integer('vector_size', 128, 'Iteration vector size')


# We cannot read flags from the parameterized test, so we cannot pass a value
# for rtol to the test function (rtol depends on the 'vector_size' flag).
# Pass this enum instead, so that the flags are read only in the test function.
class Rtol(enum.Enum):
  ZERO = 0
  BASE = 1
  AVX2 = 2


def mlir_func_1d(op_name):
  """Returns an MLIR module applying a single 1D tf.<op_name> op."""
  return f"""
  func.func @test(%arg0: tensor<?xf32>) -> tensor<?xf32> {{
    %0 = "tf.{op_name}"(%arg0): (tensor<?xf32>) -> tensor<?xf32>
    func.return %0 : tensor<?xf32>
  }}"""


def test_1d(op_name, fn, vectorize=False, lb=-1.0, ub=1.0, rtol_enum=Rtol.BASE):
  """Compiles a 1D tf.<op_name> kernel and compares its output to fn."""
  compiled = jitrt.compile(mlir_func_1d(op_name), 'test', vectorize=vectorize)
  rtols = {}
  rtols[Rtol.ZERO] = 0.0
  # Not all approximations are identical to TF's.
  rtols[Rtol.BASE] = 1e-6
  # For some ops we can match TF with the right build flags.
  # Note that the vector size also matters: for vectors whose size is not a
  # multiple of the machine's vector length, Eigen (and therefore TF) computes
  # some elements differently (e.g. via libc).
  rtols[Rtol.AVX2] = rtols[Rtol.BASE]
  # Use 16 as the machine vector length to be both simple and future-proof.
  if jitrt.built_with('AVX2') and FLAGS.vector_size % 16 == 0:
    rtols[Rtol.AVX2] = 0.0

  rtol = rtols[rtol_enum]

  for _ in range(FLAGS.iters):
    arg = np.random.uniform(lb, ub, size=FLAGS.vector_size).astype(np.float32)

    [res] = jitrt.execute(compiled, [arg])
    np.testing.assert_allclose(res, fn(arg), rtol=rtol, atol=1e-7)


class TfMathOpsTest(parameterized.TestCase):
  @parameterized.named_parameters(
      # Note: for now we are testing for identical results to TF (and therefore
      # Eigen). In the short term, this will work because Eigen's approximations
      # don't change too often. However, in the long term this could become a
      # maintenance burden.
      # TODO(ecg): relax tolerances to accommodate changes in Eigen, and add a
      # flag to control the minimum tolerance, so that we can manually check
      # for identical results to Eigen.
      ('exp_scalar', 'Exp', math.exp, False, Rtol.AVX2),
      ('exp_vector', 'Exp', math.exp, True, Rtol.AVX2),
      ('reciprocal_scalar', 'Reciprocal', math.reciprocal, False, Rtol.BASE),
      ('reciprocal_vector', 'Reciprocal', math.reciprocal, True, Rtol.BASE),
      # Rsqrt: the AVX2 intrinsic is only emitted with vectorization.
      ('rsqrt_scalar', 'Rsqrt', math.rsqrt, False, Rtol.BASE),
      ('rsqrt_vector', 'Rsqrt', math.rsqrt, True, Rtol.AVX2),
      ('tanh_scalar', 'Tanh', math.tanh, False, Rtol.AVX2),
      ('tanh_vector', 'Tanh', math.tanh, True, Rtol.AVX2),
  )
  def test_op(self, op_name, fn, vectorize, rtol_enum):
    test_1d(op_name, fn, vectorize=vectorize, rtol_enum=rtol_enum)


if __name__ == '__main__':
  np.random.seed(0)
  test.main()