• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for numerical correctness of tf.math operations."""
16
17import enum
18import numpy as np
19
20from absl import flags
21from absl.testing import parameterized
22from tensorflow import math
23
24from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
25from tensorflow.python.platform import test
26
# Single executor shared by every test case: test_1d uses it to compile the
# MLIR kernels and to execute them.
jitrt = tf_jitrt.TfJitRtExecutor()

FLAGS = flags.FLAGS
# Integer flags get integer defaults (rather than numeric strings) so the
# default values already have the flag's type.
flags.DEFINE_integer('iters', 1000, 'Number of test iterations')
flags.DEFINE_integer('vector_size', 128, 'Iteration vector size')
32
33
class Rtol(enum.Enum):
  """Symbolic relative-tolerance levels for test_1d.

  Flags cannot be read from the parameterized test case definitions, so a
  numeric rtol cannot be passed to the test function there (the AVX2 rtol
  depends on the 'vector_size' flag). Test cases carry one of these members
  instead, and test_1d maps it to a number, reading the flags only at that
  point.
  """
  ZERO = 0  # Results must match exactly.
  BASE = 1  # Small tolerance: not all approximations are identical to TF's.
  AVX2 = 2  # Exact on AVX2 builds with a suitable vector size, else as BASE.
41
42
def mlir_func_1d(op_name):
  """Builds MLIR text for a function `@test` that applies one unary TF op.

  The generated function takes a tensor<?xf32>, applies `tf.<op_name>` to it
  and returns the resulting tensor<?xf32>.

  Args:
    op_name: TF op name without the `tf.` prefix, e.g. 'Exp'.

  Returns:
    The MLIR source as a string (it begins with a newline).
  """
  tensor_type = 'tensor<?xf32>'
  source_lines = [
      '',
      f'  func.func @test(%arg0: {tensor_type}) -> {tensor_type} {{',
      f'    %0 = "tf.{op_name}"(%arg0): ({tensor_type}) -> {tensor_type}',
      f'    func.return %0 : {tensor_type}',
      '  }',
  ]
  return '\n'.join(source_lines)
49
50
def test_1d(op_name, fn, vectorize=False, lb=-1.0, ub=1.0, rtol_enum=Rtol.BASE,
            atol=1e-7):
  """Compiles `tf.<op_name>` with JitRt and compares it against `fn`.

  Runs FLAGS.iters iterations. Each one draws a float32 vector of
  FLAGS.vector_size elements uniformly from [lb, ub), runs the compiled kernel
  on it and checks the result against `fn` applied to the same input.

  Args:
    op_name: TF op name without the `tf.` prefix (e.g. 'Exp').
    fn: Reference implementation, called with the NumPy input vector.
    vectorize: Forwarded to `jitrt.compile`.
    lb: Lower bound of the uniform input distribution.
    ub: Upper bound (exclusive) of the uniform input distribution.
    rtol_enum: `Rtol` member selecting the relative tolerance.
    atol: Absolute tolerance for `np.testing.assert_allclose`.

  Raises:
    AssertionError: If any iteration's result differs from `fn` beyond the
      selected tolerances.
  """
  compiled = jitrt.compile(mlir_func_1d(op_name), 'test', vectorize=vectorize)

  # Not all approximations are identical to TF's.
  base_rtol = 1e-6
  # For some ops we can match TF with the right build flags. Vector size also
  # matters: for vectors whose size is not a multiple of the machine's vector
  # length, Eigen (and therefore TF) computes some elements differently (e.g.
  # via libc). Use 16 as the machine vector's length to be both simple and
  # future-proof.
  exact_with_avx2 = jitrt.built_with('AVX2') and FLAGS.vector_size % 16 == 0
  rtols = {
      Rtol.ZERO: 0.0,
      Rtol.BASE: base_rtol,
      Rtol.AVX2: 0.0 if exact_with_avx2 else base_rtol,
  }
  rtol = rtols[rtol_enum]

  # Identifies the failing case in assertion messages.
  err_msg = f'op={op_name} vectorize={vectorize} rtol_enum={rtol_enum}'
  for _ in range(FLAGS.iters):
    # `size` is a real 1-tuple; a parenthesized int is not a tuple.
    arg = np.random.uniform(lb, ub, size=(FLAGS.vector_size,))
    arg = arg.astype(np.float32)

    # The generated function returns exactly one tensor.
    [res] = jitrt.execute(compiled, [arg])
    np.testing.assert_allclose(
        res, fn(arg), rtol=rtol, atol=atol, err_msg=err_msg)
73
74
class TfMathOpsTest(parameterized.TestCase):
  """Compares JitRt-compiled unary ops with the tensorflow.math functions."""

  # Note: for now we are testing for identical results to TF (and therefore
  # Eigen) wherever the tolerance allows it. In the short term, this will work
  # because Eigen's approximations don't change too often. However, in the
  # long term it could become a maintenance burden.
  # TODO(ecg): relax tolerances to accommodate for changes in Eigen, and add
  # a flag to control the minimum tolerance, so that we can manually check
  # for identical results to Eigen.
  @parameterized.named_parameters(
      dict(testcase_name='exp_scalar', op_name='Exp', fn=math.exp,
           vectorize=False, rtol_enum=Rtol.AVX2),
      dict(testcase_name='exp_vector', op_name='Exp', fn=math.exp,
           vectorize=True, rtol_enum=Rtol.AVX2),
      dict(testcase_name='reciprocal_scalar', op_name='Reciprocal',
           fn=math.reciprocal, vectorize=False, rtol_enum=Rtol.BASE),
      dict(testcase_name='reciprocal_vector', op_name='Reciprocal',
           fn=math.reciprocal, vectorize=True, rtol_enum=Rtol.BASE),
      # Rsqrt: The AVX2 intrinsic is only emitted with vectorization.
      dict(testcase_name='rsqrt_scalar', op_name='Rsqrt', fn=math.rsqrt,
           vectorize=False, rtol_enum=Rtol.BASE),
      dict(testcase_name='rsqrt_vector', op_name='Rsqrt', fn=math.rsqrt,
           vectorize=True, rtol_enum=Rtol.AVX2),
      dict(testcase_name='tanh_scalar', op_name='Tanh', fn=math.tanh,
           vectorize=False, rtol_enum=Rtol.AVX2),
      dict(testcase_name='tanh_vector', op_name='Tanh', fn=math.tanh,
           vectorize=True, rtol_enum=Rtol.AVX2),
  )
  def test_op(self, op_name, fn, vectorize, rtol_enum):
    """Runs test_1d for one (op, vectorization, tolerance) combination."""
    test_1d(op_name, fn, vectorize=vectorize, rtol_enum=rtol_enum)
91      ('tanh_scalar', 'Tanh', math.tanh, False, Rtol.AVX2),
92      ('tanh_vector', 'Tanh', math.tanh, True, Rtol.AVX2),
93  )
94
95  def test_op(self, op_name, fn, vectorize, rtol_enum):
96    test_1d(op_name, fn, vectorize=vectorize, rtol_enum=rtol_enum)
97
98if __name__ == '__main__':
99  np.random.seed(0)
100  test.main()
101