# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for operator dispatch."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


class CustomTensor(object):
  """A fake composite tensor class, for testing type-based dispatching.

  Attributes:
    tensor: the wrapped value, converted with `ops.convert_to_tensor`.
    score: an arbitrary number that the overrides below combine, so the tests
      can tell whether an override (rather than the default op) ran.
  """

  def __init__(self, tensor, score):
    self.tensor = ops.convert_to_tensor(tensor)
    self.score = score


@tf_export("test_op")
@dispatch.add_dispatch_support
def test_op(x, y, z):
  """A fake op for testing dispatch of Python ops.

  Returns `x + 2*y + 3*z`. It is decorated with `add_dispatch_support` so
  that `dispatch.dispatch_for_types` can register overrides for it.
  """
  return x + (2 * y) + (3 * z)


@test_util.run_all_in_graph_and_eager_modes
class DispatchTest(test_util.TensorFlowTestCase):
  """Tests for registering type-based overrides with `dispatch_for_types`.

  Each test that registers an override snapshots the op's `_tf_dispatchers`
  list first. It restores that list in a `finally` clause, so a failing
  assertion cannot leave the override attached to a module-level op and
  affect other tests.
  """

  def testAddDispatchForTypes_With_CppOp(self):
    """An override on a generated op handles CustomTensors only."""
    original_handlers = gen_math_ops.add._tf_dispatchers[:]
    try:
      # Override the behavior of gen_math_ops.add.
      @dispatch.dispatch_for_types(gen_math_ops.add, CustomTensor)
      def custom_add(x, y, name=None):  # pylint: disable=unused-variable
        return CustomTensor(gen_math_ops.add(x.tensor, y.tensor, name),
                            (x.score+y.score) / 2.0)
      # NOTE(review): this relies on math_ops.add being the same object as
      # gen_math_ops.add (so it sees the newly registered dispatcher) — confirm.
      self.assertEqual(len(math_ops.add._tf_dispatchers),
                       len(original_handlers) + 1)

      # Test that we see the overridden behavior when using CustomTensors.
      x = CustomTensor([1, 2, 3], 2.0)
      y = CustomTensor([7, 8, 2], 0.0)
      x_plus_y = gen_math_ops.add(x, y)
      self.assertAllEqual(self.evaluate(x_plus_y.tensor), [8, 10, 5])
      self.assertNear(x_plus_y.score, 1.0, 0.001)

      # Test that we still get the right behavior when using normal Tensors.
      # Evaluate explicitly so the check is identical in graph and eager mode.
      a = [1, 2, 3]
      b = [4, 5, 6]
      a_plus_b = gen_math_ops.add(a, b)
      self.assertAllEqual(self.evaluate(a_plus_b), [5, 7, 9])

      # Test that we still get a TypeError or ValueError if we pass some
      # type that's not supported by any dispatcher.
      with self.assertRaises((TypeError, ValueError)):
        gen_math_ops.add(a, None)
    finally:
      # Clean up even if an assertion above failed.
      gen_math_ops.add._tf_dispatchers = original_handlers

  def testAddDispatchForTypes_With_PythonOp(self):
    """An override on a Python op decorated with add_dispatch_support runs."""
    original_handlers = test_op._tf_dispatchers[:]
    try:
      @dispatch.dispatch_for_types(test_op, CustomTensor)
      def override_for_test_op(x, y, z):  # pylint: disable=unused-variable
        return CustomTensor(test_op(x.tensor, y.tensor, z.tensor),
                            (x.score + y.score + z.score) / 3.0)

      x = CustomTensor([1, 2, 3], 0.2)
      y = CustomTensor([7, 8, 2], 0.4)
      z = CustomTensor([0, 1, 2], 0.6)

      result = test_op(x, y, z)
      self.assertAllEqual(self.evaluate(result.tensor), [15, 21, 13])
      self.assertNear(result.score, 0.4, 0.001)
    finally:
      # Clean up even if an assertion above failed.
      test_op._tf_dispatchers = original_handlers

  def testDispatchForTypes_SignatureMismatch(self):
    """An override whose parameter names differ from the op's is rejected."""
    original_handlers = test_op._tf_dispatchers[:]
    try:
      with self.assertRaisesRegexp(AssertionError, "The decorated function's "
                                   "signature must exactly match.*"):
        @dispatch.dispatch_for_types(test_op, CustomTensor)
        def override_for_test_op(a, b, c):  # pylint: disable=unused-variable
          return CustomTensor(test_op(a.tensor, b.tensor, c.tensor),
                              (a.score + b.score + c.score) / 3.0)
    finally:
      # If registration unexpectedly succeeded, do not leak the override.
      test_op._tf_dispatchers = original_handlers

  def testDispatchForTypes_OpDoesNotSupportDispatch(self):
    """Registering for an op without add_dispatch_support is rejected."""
    def some_op(x, y):
      return x + y

    with self.assertRaisesRegexp(AssertionError, "Dispatching not enabled for"):
      @dispatch.dispatch_for_types(some_op, CustomTensor)
      def override_for_some_op(x, y):  # pylint: disable=unused-variable
        return x if x.score > 0 else y


if __name__ == "__main__":
  googletest.main()