• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensorflow.compiler.mlir.tfr.integration.graph_decompose."""
15
16import os
17
18from tensorflow.compiler.mlir.tfr.resources import gen_composite_ops
19from tensorflow.python.eager import def_function
20from tensorflow.python.framework import constant_op
21from tensorflow.python.framework import load_library
22from tensorflow.python.framework import ops
23from tensorflow.python.ops import nn_ops
24from tensorflow.python.platform import test
25
def _op_library_path(module_file):
  """Returns the path of the op library that backs a generated op module.

  The generated Python wrapper `<dir>/gen_<name>.py` is expected to sit in the
  same directory as the op library `<dir>/<name>.so`.

  Args:
    module_file: `__file__` of the generated Python wrapper module.

  Returns:
    Path to the matching `.so` file in the same directory as `module_file`.
  """
  lib_dir, module_name = os.path.split(module_file)
  # splitext (rather than str.replace('.py', '.so')) only touches the real
  # extension, so `.pyc` wrappers and names containing '.py' elsewhere work.
  stem = os.path.splitext(module_name)[0]
  # Strip the `gen_` prefix only when it is present, instead of blindly
  # dropping the first four characters.
  if stem.startswith('gen_'):
    stem = stem[len('gen_'):]
  return os.path.join(lib_dir, stem + '.so')


# Kept as module-level names for compatibility with the original layout.
_lib_dir, _lib_name = os.path.split(_op_library_path(gen_composite_ops.__file__))
load_library.load_op_library(os.path.join(_lib_dir, _lib_name))
30
31
class GraphDecomposeTest(test.TestCase):
  """Tests for the composite ops defined in `gen_composite_ops`.

  Each test wraps a composite op in `def_function.function` and compares the
  eagerly evaluated result with hand-computed values.
  """

  # 2x2 matrix used as the (non-bias) operand in every test.
  _MATRIX = [[1.0, 2.0], [3.0, 4.0]]
  # Bias operand passed to `my_biased_dense`.
  _BIAS = [[-10.0, -10.0], [-10.0, -10.0]]

  def _dense_inputs(self):
    """Returns fresh `(x, y, bias)` constant tensors for the dense tests.

    The expected values in the dense tests are consistent with
    `matmul(x, y) + bias`: [[1, 2], [3, 4]] @ [[1, 2], [3, 4]] is
    [[7, 10], [15, 22]], and adding -10 gives [[-3, 0], [5, 12]].
    """
    return (constant_op.constant(self._MATRIX),
            constant_op.constant(self._MATRIX),
            constant_op.constant(self._BIAS))

  def testAddN(self):
    add = def_function.function(gen_composite_ops.my_add_n)
    inputs = [constant_op.constant(self._MATRIX) for _ in range(3)]
    # Adding `count` copies of the same matrix scales each element by `count`.
    for count in (1, 2, 3):
      total = add(inputs[:count])
      self.assertAllEqual(total.numpy().reshape(-1),
                          [count * v for v in (1, 2, 3, 4)])

  def testBiasedDense(self):
    biased_dense = def_function.function(gen_composite_ops.my_biased_dense)
    result = biased_dense(*self._dense_inputs())
    self.assertAllEqual(result.numpy().reshape(-1), [-3, 0, 5, 12])

  def testBiasedDenseRelu(self):
    biased_dense = def_function.function(gen_composite_ops.my_biased_dense)
    result = biased_dense(*self._dense_inputs(), act='relu')
    # relu maps the non-positive entries (-3 and 0) to zero.
    self.assertAllEqual(result.numpy().reshape(-1), [0, 0, 5, 12])

  def testWithKnownKernel(self):

    @def_function.function
    def biased_dense_elu(x, y, z):
      dot = gen_composite_ops.my_biased_dense(x, y, z)
      return nn_ops.elu(dot)  # with known kernel, should not expand.

    result = biased_dense_elu(*self._dense_inputs())
    # elu(-3) = exp(-3) - 1, approx. -0.950213; non-negative inputs pass through.
    self.assertAllClose(result.numpy().reshape(-1), [-0.950213, 0, 5, 12])
74
75
if __name__ == '__main__':
  # Directory handed to the TFR machinery via the environment; presumably read
  # when composite ops are decomposed (TODO confirm against the TFR pass).
  tfr_lib_dir = 'tensorflow/compiler/mlir/tfr/resources'
  os.environ['TF_MLIR_TFR_LIB_DIR'] = tfr_lib_dir
  ops.enable_eager_execution()
  test.main()
80