• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for operator dispatch."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import test_util
23from tensorflow.python.ops import gen_math_ops
24from tensorflow.python.ops import math_ops
25from tensorflow.python.platform import googletest
26from tensorflow.python.util import dispatch
27from tensorflow.python.util.tf_export import tf_export
28
29
class CustomTensor(object):
  """Stand-in "composite tensor" used to exercise type-based dispatch.

  Instances are not Tensors, so ordinary ops cannot consume them directly;
  only a dispatcher registered for `CustomTensor` knows how to unpack them.

  Attributes:
    tensor: the wrapped value, converted with `ops.convert_to_tensor`.
    score: an arbitrary companion value stored unchanged (the tests pass
      Python floats and check how dispatchers combine them).
  """

  def __init__(self, tensor, score):
    # Normalize the payload (the tests pass plain Python lists) to a Tensor.
    converted = ops.convert_to_tensor(tensor)
    self.tensor = converted
    self.score = score
36
37
@tf_export("test_op")
@dispatch.add_dispatch_support
def test_op(x, y, z):
  """Toy Python op with dispatch support; computes `x + 2*y + 3*z`.

  It exists only so the tests can register type-based overrides on a
  Python-defined op (as opposed to a generated C++ op wrapper).
  """
  # Same association and evaluation order as `x + (2 * y) + (3 * z)`.
  head = x + 2 * y
  return head + 3 * z
43
44
@test_util.run_all_in_graph_and_eager_modes
class DispatchTest(test_util.TensorFlowTestCase):
  """Tests for `dispatch.dispatch_for_types` and `add_dispatch_support`.

  Tests that register a dispatcher snapshot the op's `_tf_dispatchers` list
  first and restore it in a `finally` clause. A failing assertion therefore
  cannot leak the override into other tests, or into the other execution
  mode that `run_all_in_graph_and_eager_modes` runs the same test body in.
  """

  def testAddDispatchForTypes_With_CppOp(self):
    """A type dispatcher can override a generated (C++-backed) op."""
    original_handlers = gen_math_ops.add._tf_dispatchers[:]

    try:
      # Override the behavior of gen_math_ops.add for CustomTensor inputs.
      @dispatch.dispatch_for_types(gen_math_ops.add, CustomTensor)
      def custom_add(x, y, name=None):  # pylint: disable=unused-variable
        return CustomTensor(gen_math_ops.add(x.tensor, y.tensor, name),
                            (x.score + y.score) / 2.0)

      # Registration must add exactly one handler to the op we decorated
      # (gen_math_ops.add) -- the same list that is restored below.
      self.assertEqual(len(gen_math_ops.add._tf_dispatchers),
                       len(original_handlers) + 1)

      # Test that we see the overridden behavior when using CustomTensors.
      x = CustomTensor([1, 2, 3], 2.0)
      y = CustomTensor([7, 8, 2], 0.0)
      x_plus_y = gen_math_ops.add(x, y)
      self.assertAllEqual(self.evaluate(x_plus_y.tensor), [8, 10, 5])
      self.assertNear(x_plus_y.score, 1.0, 0.001)

      # Test that we still get the right behavior when using normal Tensors.
      # Evaluate explicitly so the comparison also works in graph mode.
      a = [1, 2, 3]
      b = [4, 5, 6]
      a_plus_b = gen_math_ops.add(a, b)
      self.assertAllEqual(self.evaluate(a_plus_b), [5, 7, 9])

      # Test that we still get a TypeError or ValueError if we pass some
      # type that's not supported by any dispatcher.
      with self.assertRaises((TypeError, ValueError)):
        gen_math_ops.add(a, None)
    finally:
      # Clean up even when an assertion above failed.
      gen_math_ops.add._tf_dispatchers = original_handlers

  def testAddDispatchForTypes_With_PythonOp(self):
    """A type dispatcher can override a Python op with dispatch support."""
    original_handlers = test_op._tf_dispatchers[:]

    try:
      @dispatch.dispatch_for_types(test_op, CustomTensor)
      def override_for_test_op(x, y, z):  # pylint: disable=unused-variable
        return CustomTensor(test_op(x.tensor, y.tensor, z.tensor),
                            (x.score + y.score + z.score) / 3.0)

      x = CustomTensor([1, 2, 3], 0.2)
      y = CustomTensor([7, 8, 2], 0.4)
      z = CustomTensor([0, 1, 2], 0.6)

      result = test_op(x, y, z)
      self.assertAllEqual(self.evaluate(result.tensor), [15, 21, 13])
      self.assertNear(result.score, 0.4, 0.001)
    finally:
      # Clean up even when an assertion above failed.
      test_op._tf_dispatchers = original_handlers

  def testDispatchForTypes_SignatureMismatch(self):
    """Registering a handler whose signature differs from the op fails."""
    # assertRaisesRegexp is kept (not assertRaisesRegex) because this file
    # still targets Python 2 via its __future__ imports.
    with self.assertRaisesRegexp(AssertionError, "The decorated function's "
                                 "signature must exactly match.*"):
      @dispatch.dispatch_for_types(test_op, CustomTensor)
      def override_for_test_op(a, b, c):  # pylint: disable=unused-variable
        return CustomTensor(test_op(a.tensor, b.tensor, c.tensor),
                            (a.score + b.score + c.score) / 3.0)

  def testDispatchForTypes_OpDoesNotSupportDispatch(self):
    """Ops lacking add_dispatch_support reject dispatcher registration."""
    def some_op(x, y):
      return x + y

    with self.assertRaisesRegexp(AssertionError, "Dispatching not enabled for"):
      @dispatch.dispatch_for_types(some_op, CustomTensor)
      def override_for_some_op(x, y):  # pylint: disable=unused-variable
        return x if x.score > 0 else y
115
116
# Script entry point: when run directly (not imported), hand control to
# TensorFlow's googletest runner.
if __name__ == "__main__":
  googletest.main()
119
120
121