1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for operator dispatch.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.python.framework import dtypes 22from tensorflow.python.framework import ops 23from tensorflow.python.framework import test_util 24from tensorflow.python.ops import array_ops 25from tensorflow.python.ops import gen_math_ops 26from tensorflow.python.ops import math_ops 27from tensorflow.python.ops.linalg import linear_operator_diag 28from tensorflow.python.ops.proto_ops import decode_proto 29from tensorflow.python.platform import googletest 30from tensorflow.python.platform import test 31from tensorflow.python.platform import tf_logging 32from tensorflow.python.util import deprecation 33from tensorflow.python.util import dispatch 34from tensorflow.python.util import nest 35from tensorflow.python.util.tf_export import get_canonical_name_for_symbol 36from tensorflow.python.util.tf_export import tf_export 37 38 39class CustomTensor(object): 40 """A fake composite tensor class, for testing type-based dispatching.""" 41 42 def __init__(self, tensor, score): 43 self.tensor = ops.convert_to_tensor(tensor) 44 self.score = score 45 46 47@tf_export("test_op") 48@dispatch.add_dispatch_support 49def test_op(x, y, z): 50 """A fake op for testing dispatch of Python ops.""" 51 return x + (2 * y) + (3 * z) 52 53 54class TensorTracer(object): 55 """An object used to trace TensorFlow graphs. 56 57 This is an example class that is used to test global op dispatchers. The 58 global op dispatcher for TensorTracers is defined below. 59 """ 60 61 def __init__(self, name, args=None, kwargs=None): 62 self.name = name 63 self.args = args 64 self.kwargs = kwargs 65 self.shape = array_ops.ones(shape=(4, 4)).shape 66 self.dtype = dtypes.float32 67 68 def __repr__(self): 69 if self.args is None and self.kwargs is None: 70 return self.name 71 else: 72 args = [str(x) for x in self.args] 73 args += sorted( 74 ["{}={}".format(name, x) for (name, x) in self.kwargs.items()]) 75 return "{}({})".format(self.name, ", ".join(args)) 76 77 @property 78 def is_tensor_like(self): 79 return True 80 81 @classmethod 82 def _overload_all_operators(cls): # pylint: disable=invalid-name 83 """Register overloads for all operators.""" 84 for operator in ops.Tensor.OVERLOADABLE_OPERATORS: 85 cls._overload_operator(operator) 86 87 @classmethod 88 def _overload_operator(cls, operator): # pylint: disable=invalid-name 89 """Overload an operator with the same overloading as `ops.Tensor`.""" 90 tensor_oper = getattr(ops.Tensor, operator) 91 92 # Compatibility with Python 2: 93 # Python 2 unbound methods have type checks for the first arg, 94 # so we need to extract the underlying function 95 tensor_oper = getattr(tensor_oper, "__func__", tensor_oper) 96 setattr(cls, operator, tensor_oper) 97 98TensorTracer._overload_all_operators() # pylint: disable=protected-access 99 100 101class TensorTracerOpDispatcher(dispatch.GlobalOpDispatcher): 102 """Global op dispatcher for TensorTracer.""" 103 104 def _flatten_with_slice_flattening(self, x): 105 flat = [] 106 for val in nest.flatten(x): 107 if isinstance(val, slice): 108 flat.extend((val.start, val.stop, val.step)) 109 else: 110 flat.append(val) 111 return flat 112 113 def handle(self, op, args, kwargs): 114 # Dispatcher only applies if at least one arg is a TensorTracer. 115 if not (any(self.is_tensor_tracer_arg(x) for x in args) or 116 any(self.is_tensor_tracer_arg(x) for x in kwargs.values())): 117 return self.NOT_SUPPORTED 118 119 symbol_name = get_canonical_name_for_symbol(op) 120 return TensorTracer(symbol_name, args, kwargs) 121 122 def is_tensor_tracer_arg(self, value): 123 return any(isinstance(x, TensorTracer) for x in 124 self._flatten_with_slice_flattening(value)) 125 126 127@test_util.run_all_in_graph_and_eager_modes 128class DispatchTest(test_util.TensorFlowTestCase): 129 130 def testAddDispatchForTypes_With_CppOp(self): 131 original_handlers = gen_math_ops.add._tf_dispatchers[:] 132 133 # Override the behavior of gen_math_ops.add. 134 @dispatch.dispatch_for_types(gen_math_ops.add, CustomTensor) 135 def custom_add(x, y, name=None): # pylint: disable=unused-variable 136 return CustomTensor(gen_math_ops.add(x.tensor, y.tensor, name), 137 (x.score+y.score) / 2.0) 138 self.assertEqual(len(math_ops.add._tf_dispatchers), 139 len(original_handlers) + 1) 140 141 # Test that we see the overridden behavior when using CustomTensors. 142 x = CustomTensor([1, 2, 3], 2.0) 143 y = CustomTensor([7, 8, 2], 0.0) 144 x_plus_y = gen_math_ops.add(x, y) 145 self.assertAllEqual(self.evaluate(x_plus_y.tensor), [8, 10, 5]) 146 self.assertNear(x_plus_y.score, 1.0, 0.001) 147 148 # Test that we still get the right behavior when using normal Tensors. 149 a = [1, 2, 3] 150 b = [4, 5, 6] 151 a_plus_b = gen_math_ops.add(a, b) 152 self.assertAllEqual(a_plus_b, [5, 7, 9]) 153 154 # Test that we still get a TypeError or ValueError if we pass some 155 # type that's not supported by any dispatcher. 156 with self.assertRaises((TypeError, ValueError)): 157 gen_math_ops.add(a, None) 158 159 # Clean up 160 gen_math_ops.add._tf_dispatchers = original_handlers 161 162 def testAddDispatchForTypes_With_PythonOp(self): 163 original_handlers = test_op._tf_dispatchers[:] 164 165 @dispatch.dispatch_for_types(test_op, CustomTensor) 166 def override_for_test_op(x, y, z): # pylint: disable=unused-variable 167 return CustomTensor(test_op(x.tensor, y.tensor, z.tensor), 168 (x.score + y.score + z.score) / 3.0) 169 170 x = CustomTensor([1, 2, 3], 0.2) 171 y = CustomTensor([7, 8, 2], 0.4) 172 z = CustomTensor([0, 1, 2], 0.6) 173 174 result = test_op(x, y, z) 175 self.assertAllEqual(self.evaluate(result.tensor), [15, 21, 13]) 176 self.assertNear(result.score, 0.4, 0.001) 177 178 # Clean up 179 test_op._tf_dispatchers = original_handlers 180 181 def testDispatchForTypes_SignatureMismatch(self): 182 with self.assertRaisesRegex( 183 AssertionError, "The decorated function's " 184 "signature must exactly match.*"): 185 186 @dispatch.dispatch_for_types(test_op, CustomTensor) 187 def override_for_test_op(a, b, c): # pylint: disable=unused-variable 188 return CustomTensor(test_op(a.tensor, b.tensor, c.tensor), 189 (a.score + b.score + c.score) / 3.0) 190 191 def testDispatchForTypes_OpDoesNotSupportDispatch(self): 192 def some_op(x, y): 193 return x + y 194 195 with self.assertRaisesRegex(AssertionError, "Dispatching not enabled for"): 196 197 @dispatch.dispatch_for_types(some_op, CustomTensor) 198 def override_for_some_op(x, y): # pylint: disable=unused-variable 199 return x if x.score > 0 else y 200 201 @test.mock.patch.object(tf_logging, "warning", autospec=True) 202 def testInteractionWithDeprecationWarning(self, mock_warning): 203 @deprecation.deprecated(date=None, instructions="Instructions") 204 @dispatch.add_dispatch_support 205 def some_op(x): 206 return x 207 208 some_op(5) 209 210 message = mock_warning.call_args[0][0] % mock_warning.call_args[0][1:] 211 self.assertRegex( 212 message, r".*some_op \(from __main__\) is deprecated and will be " 213 "removed in a future version.*") 214 215 def testGlobalDispatcher(self): 216 original_global_dispatchers = dispatch._GLOBAL_DISPATCHERS 217 try: 218 TensorTracerOpDispatcher().register() 219 220 x = TensorTracer("x") 221 y = TensorTracer("y") 222 trace = math_ops.reduce_sum(math_ops.add(math_ops.abs(x), y), axis=3) 223 self.assertEqual( 224 str(trace), 225 "math.reduce_sum(math.add(name=None, x=math.abs(x), y=y), axis=3)") 226 227 proto_val = TensorTracer("proto") 228 trace = decode_proto(proto_val, "message_type", ["field"], ["float32"]) 229 self.assertIn("io.decode_proto(bytes=proto,", str(trace)) 230 231 finally: 232 # Clean up. 233 dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers 234 235 def testGlobalDispatcherConvertToTensor(self): 236 original_global_dispatchers = dispatch._GLOBAL_DISPATCHERS 237 try: 238 TensorTracerOpDispatcher().register() 239 240 x = TensorTracer("x") 241 y = TensorTracer("y") 242 trace = math_ops.add(math_ops.abs( 243 ops.convert_to_tensor_v2_with_dispatch(x)), y) 244 self.assertEqual( 245 str(trace), 246 "math.add(name=None, x=math.abs(convert_to_tensor(x)), y=y)") 247 248 finally: 249 # Clean up. 250 dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers 251 252 def testGlobalDispatcherGetItem(self): 253 original_global_dispatchers = dispatch._GLOBAL_DISPATCHERS 254 try: 255 TensorTracerOpDispatcher().register() 256 257 x = TensorTracer("x") 258 trace = x[0] 259 self.assertEqual( 260 str(trace), 261 "__operators__.getitem(x, 0)") 262 263 x = TensorTracer("x") 264 y = TensorTracer("y") 265 trace = x[y] 266 self.assertEqual( 267 str(trace), 268 "__operators__.getitem(x, y)") 269 270 x = TensorTracer("x") 271 y = TensorTracer("y") 272 trace = x[:y] # pylint: disable=invalid-slice-index 273 self.assertEqual( 274 str(trace), 275 "__operators__.getitem(x, slice(None, y, None))") 276 277 x = array_ops.ones(shape=(3, 3)) 278 y = TensorTracer("y") 279 trace = x[y] 280 self.assertEqual( 281 str(trace), 282 "__operators__.getitem(%s, y)" % x) 283 284 trace = x[:y] # pylint: disable=invalid-slice-index 285 self.assertEqual( 286 str(trace), 287 "__operators__.getitem(%s, slice(None, y, None))" % x) 288 289 finally: 290 # Clean up. 291 dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers 292 293 def testGlobalDispatcherLinearOperators(self): 294 original_global_dispatchers = dispatch._GLOBAL_DISPATCHERS 295 try: 296 TensorTracerOpDispatcher().register() 297 298 x = TensorTracer("x") 299 300 # To grab the eigenvalues the diag operator just calls convert_to_tensor 301 # (twice) in this case. 302 trace = linear_operator_diag.LinearOperatorDiag(x).eigvals() 303 self.assertEqual( 304 str(trace), 305 "convert_to_tensor(convert_to_tensor(x, dtype=None, dtype_hint=None, " 306 "name=diag))") 307 308 # The diagonal tensor addition gets traced even though the linear_operator 309 # API only uses dispatchable ops instead of directly exposing dispatching. 310 trace = linear_operator_diag.LinearOperatorDiag(x).add_to_tensor(x) 311 self.assertIn( 312 "linalg.set_diag(convert_to_tensor(x, name=x), __operators__.add(" 313 "convert_to_tensor(x, dtype=None, dtype_hint=None, name=diag), " 314 "linalg.diag_part(convert_to_tensor(x, name=x)), " 315 "name=", 316 str(trace)) 317 318 # The dispatch-supporting ops the non-singular check calls out to 319 # get traced. 320 trace = linear_operator_diag.LinearOperatorDiag(x).assert_non_singular() 321 self.assertIn("debugging.assert_less", str(trace)) 322 self.assertIn( 323 "message=Singular operator: Diagonal contained zero values.", 324 str(trace)) 325 326 finally: 327 # Clean up. 328 dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers 329 330if __name__ == "__main__": 331 googletest.main() 332