• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for tensorflow.compiler.mlir.tfr.integration.node_expansion."""
15
16import os
17
18from tensorflow.compiler.mlir.tfr.resources import gen_composite_ops
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import load_library
22from tensorflow.python.framework import ops
23from tensorflow.python.ops import gen_resource_variable_ops
24from tensorflow.python.ops import nn_ops
25from tensorflow.python.platform import test
26
# The generated wrapper module is "gen_<lib>.py"; the op library it wraps is
# loaded from "<lib>.so" in the same directory. Loading it at import time
# registers the composite ops before any test method calls them.
_lib_dir, _lib_name = os.path.split(gen_composite_ops.__file__)
_lib_name = _lib_name[len('gen_'):].replace('.py', '.so')
load_library.load_op_library(os.path.join(_lib_dir, _lib_name))
31
32
class NodeExpansionTest(test.TestCase):
  """Runs the TFR composite ops from gen_composite_ops and checks results.

  Every test builds small constant 2x2 (or 3x2) float tensors, calls a
  composite op (or a resource-variable op), and compares the eager result
  against hand-computed expected values.
  """

  def testAddN(self):
    """my_add_n sums lists of one, two, and three identical 2x2 tensors."""
    t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    t3 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    sq1 = gen_composite_ops.my_add_n([t1])
    sq2 = gen_composite_ops.my_add_n([t1, t2])
    sq3 = gen_composite_ops.my_add_n([t1, t2, t3])
    # Summing k copies of [[1, 2], [3, 4]] multiplies every element by k.
    self.assertAllEqual(sq1.numpy().reshape(-1), [1, 2, 3, 4])
    self.assertAllEqual(sq2.numpy().reshape(-1), [2, 4, 6, 8])
    self.assertAllEqual(sq3.numpy().reshape(-1), [3, 6, 9, 12])

  def testBiasedDense(self):
    """my_biased_dense with the default activation."""
    t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]])
    sq = gen_composite_ops.my_biased_dense(t1, t2, t3)
    # The expected values match matmul(t1, t2) + t3:
    # [[7, 10], [15, 22]] - 10 = [[-3, 0], [5, 12]].
    self.assertAllEqual(sq.numpy().reshape(-1), [-3, 0, 5, 12])

  def testBiasedDenseRelu(self):
    """my_biased_dense with act='relu' clamps negative outputs to zero."""
    t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]])
    sq = gen_composite_ops.my_biased_dense(t1, t2, t3, act='relu')
    # Same pre-activation values as testBiasedDense; relu maps -3 to 0.
    self.assertAllEqual(sq.numpy().reshape(-1), [0, 0, 5, 12])

  def testWithKnownKernel(self):
    """Feeds a composite op's output into an op that has a known kernel."""

    def biased_dense_elu(x, y, z):
      dot = gen_composite_ops.my_biased_dense(x, y, z)
      # Elu has a known kernel, so it should not be expanded.
      return nn_ops.elu(dot)

    t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]])
    sq = biased_dense_elu(t1, t2, t3)
    # elu(-3) = exp(-3) - 1 ~= -0.950213; non-negative inputs pass through.
    self.assertAllClose(sq.numpy().reshape(-1), [-0.950213, 0, 5, 12])

  def testVarHandleOp(self):
    """Regression test for VarHandleOp import during "no-op" node expansion.

    VarHandleOp was previously not imported into MLIR correctly for "no-op"
    node expansion.
    """
    x = constant_op.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

    # VarHandleOp is deliberately called more than once so that the second
    # call goes through the cached kernel lookup path where the import issue
    # showed up. The first handle itself is not used.
    _ = gen_resource_variable_ops.VarHandleOp(
        dtype=dtypes.float32, shape=[3, 2])
    handle = gen_resource_variable_ops.VarHandleOp(
        dtype=dtypes.float32, shape=[3, 2])
    gen_resource_variable_ops.AssignVariableOp(resource=handle, value=x)
    # Reading the variable back must return exactly what was assigned.
    self.assertAllEqual(
        x,
        gen_resource_variable_ops.ReadVariableOp(
            resource=handle, dtype=dtypes.float32))
89
90
if __name__ == '__main__':
  # NOTE(review): presumably TFR reads this variable to locate the composite
  # op definitions. The path is relative, so it likely assumes the test runs
  # from the workspace root — confirm.
  os.environ.update(
      {'TF_MLIR_TFR_LIB_DIR': 'tensorflow/compiler/mlir/tfr/resources'})
  # The tests call .numpy() on op outputs, which requires eager execution.
  ops.enable_eager_execution()
  test.main()
95