• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for operator dispatch."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import test_util
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import gen_math_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops.linalg import linear_operator_diag
28from tensorflow.python.ops.proto_ops import decode_proto
29from tensorflow.python.platform import googletest
30from tensorflow.python.platform import test
31from tensorflow.python.platform import tf_logging
32from tensorflow.python.util import deprecation
33from tensorflow.python.util import dispatch
34from tensorflow.python.util import nest
35from tensorflow.python.util.tf_export import get_canonical_name_for_symbol
36from tensorflow.python.util.tf_export import tf_export
37
38
class CustomTensor(object):
  """Stand-in for a composite tensor, used to exercise type-based dispatch.

  Attributes:
    tensor: The `tensor` constructor argument after `ops.convert_to_tensor`.
    score: The `score` constructor argument, stored as given.
  """

  def __init__(self, tensor, score):
    converted = ops.convert_to_tensor(tensor)
    self.tensor = converted
    self.score = score
45
46
@tf_export("test_op")
@dispatch.add_dispatch_support
def test_op(x, y, z):
  """Toy dispatch-enabled Python op returning `x + 2*y + 3*z`.

  Exists so tests can register type-based dispatchers on an op implemented
  in Python (as opposed to a generated C++ op wrapper).
  """
  # Same evaluation order as `x + (2 * y) + (3 * z)`.
  partial_sum = x + 2 * y
  return partial_sum + 3 * z
52
53
class TensorTracer(object):
  """An object used to trace TensorFlow graphs.

  This is an example class that is used to test global op dispatchers. The
  global op dispatcher for TensorTracers (`TensorTracerOpDispatcher`, defined
  below) answers a dispatchable op call involving a TensorTracer with a new
  TensorTracer that records the call, so `repr()` shows the traced expression.

  Attributes:
    name: The tracer's name, or the canonical op name for a traced call.
    args: Positional arguments of the traced call, or None.
    kwargs: Keyword arguments of the traced call, or None.
    shape: The static shape of a (4, 4) `array_ops.ones` tensor.
    dtype: Always `dtypes.float32`.
  """

  def __init__(self, name, args=None, kwargs=None):
    self.name = name
    self.args = args
    self.kwargs = kwargs
    # Fixed, tensor-like shape/dtype; presumably so code that inspects these
    # attributes (e.g. the linear operators in the tests) accepts a tracer.
    self.shape = array_ops.ones(shape=(4, 4)).shape
    self.dtype = dtypes.float32

  def __repr__(self):
    """Renders the tracer as `name` or `name(arg, ..., key=value, ...)`.

    Positional args keep their order; keyword args are rendered as
    `key=value` strings and sorted so the output is deterministic. A tracer
    with neither args nor kwargs renders as just its name.
    """
    if self.args is None and self.kwargs is None:
      return self.name
    # Either field may still be None on its own; treat it as empty instead of
    # failing while iterating None.
    positional = () if self.args is None else self.args
    keyword = {} if self.kwargs is None else self.kwargs
    args = [str(x) for x in positional]
    args += sorted(
        ["{}={}".format(name, x) for (name, x) in keyword.items()])
    return "{}({})".format(self.name, ", ".join(args))

  @property
  def is_tensor_like(self):
    """Always True: marks the tracer as tensor-like."""
    return True

  @classmethod
  def _overload_all_operators(cls):  # pylint: disable=invalid-name
    """Register overloads for all operators."""
    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
      cls._overload_operator(operator)

  @classmethod
  def _overload_operator(cls, operator):  # pylint: disable=invalid-name
    """Overload an operator with the same overloading as `ops.Tensor`."""
    tensor_oper = getattr(ops.Tensor, operator)

    # Compatibility with Python 2:
    # Python 2 unbound methods have type checks for the first arg,
    # so we need to extract the underlying function
    tensor_oper = getattr(tensor_oper, "__func__", tensor_oper)
    setattr(cls, operator, tensor_oper)
97
# Copy ops.Tensor's overloadable operators onto TensorTracer so that Python
# operator syntax on a tracer (e.g. `x[0]` in testGlobalDispatcherGetItem)
# runs the same implementations a Tensor would, which the global dispatcher
# can then trace.
TensorTracer._overload_all_operators()  # pylint: disable=protected-access
99
100
class TensorTracerOpDispatcher(dispatch.GlobalOpDispatcher):
  """Global op dispatcher that turns op calls on TensorTracers into traces.

  If any argument of a dispatched op contains a TensorTracer, `handle` returns
  a new TensorTracer named after the op's canonical symbol and holding the
  call's arguments; otherwise it returns `NOT_SUPPORTED`.
  """

  def _flatten_with_slice_flattening(self, x):
    """Returns `nest.flatten(x)` with every slice leaf expanded in place.

    Each slice is replaced by its `start`, `stop` and `step` values, so a
    tracer used as a slice bound (as in `x[:y]`) is still visible.
    """
    leaves = []
    for leaf in nest.flatten(x):
      leaves.extend((leaf.start, leaf.stop, leaf.step)
                    if isinstance(leaf, slice) else (leaf,))
    return leaves

  def handle(self, op, args, kwargs):
    """Traces `op(*args, **kwargs)` when a TensorTracer is among the inputs."""
    # Positional arguments are examined before keyword-argument values, and
    # the search stops at the first tracer found.
    involves_tracer = any(
        self.is_tensor_tracer_arg(value)
        for group in (args, kwargs.values())
        for value in group)
    if not involves_tracer:
      return self.NOT_SUPPORTED
    return TensorTracer(get_canonical_name_for_symbol(op), args, kwargs)

  def is_tensor_tracer_arg(self, value):
    """Returns True if `value` (flattened, slices expanded) holds a tracer."""
    flat_values = self._flatten_with_slice_flattening(value)
    return any(isinstance(item, TensorTracer) for item in flat_values)
125
126
@test_util.run_all_in_graph_and_eager_modes
class DispatchTest(test_util.TensorFlowTestCase):
  """Tests for type-based (`dispatch_for_types`) and global op dispatch.

  Tests that register a dispatcher first snapshot a *copy* of the affected
  dispatcher list and restore it in a `finally` clause. This means a failed
  assertion cannot leave a test-only dispatcher registered for later tests.
  """

  def testAddDispatchForTypes_With_CppOp(self):
    original_handlers = gen_math_ops.add._tf_dispatchers[:]
    try:
      # Override the behavior of gen_math_ops.add.
      @dispatch.dispatch_for_types(gen_math_ops.add, CustomTensor)
      def custom_add(x, y, name=None):  # pylint: disable=unused-variable
        return CustomTensor(gen_math_ops.add(x.tensor, y.tensor, name),
                            (x.score + y.score) / 2.0)

      # Exactly one handler was added to the op we registered on.
      self.assertEqual(len(gen_math_ops.add._tf_dispatchers),
                       len(original_handlers) + 1)

      # Test that we see the overridden behavior when using CustomTensors.
      x = CustomTensor([1, 2, 3], 2.0)
      y = CustomTensor([7, 8, 2], 0.0)
      x_plus_y = gen_math_ops.add(x, y)
      self.assertAllEqual(self.evaluate(x_plus_y.tensor), [8, 10, 5])
      self.assertNear(x_plus_y.score, 1.0, 0.001)

      # Test that we still get the right behavior when using normal Tensors.
      a = [1, 2, 3]
      b = [4, 5, 6]
      a_plus_b = gen_math_ops.add(a, b)
      self.assertAllEqual(a_plus_b, [5, 7, 9])

      # Test that we still get a TypeError or ValueError if we pass some
      # type that's not supported by any dispatcher.
      with self.assertRaises((TypeError, ValueError)):
        gen_math_ops.add(a, None)
    finally:
      # Clean up, even if an assertion above failed.
      gen_math_ops.add._tf_dispatchers = original_handlers

  def testAddDispatchForTypes_With_PythonOp(self):
    original_handlers = test_op._tf_dispatchers[:]
    try:
      @dispatch.dispatch_for_types(test_op, CustomTensor)
      def override_for_test_op(x, y, z):  # pylint: disable=unused-variable
        return CustomTensor(test_op(x.tensor, y.tensor, z.tensor),
                            (x.score + y.score + z.score) / 3.0)

      x = CustomTensor([1, 2, 3], 0.2)
      y = CustomTensor([7, 8, 2], 0.4)
      z = CustomTensor([0, 1, 2], 0.6)

      result = test_op(x, y, z)
      self.assertAllEqual(self.evaluate(result.tensor), [15, 21, 13])
      self.assertNear(result.score, 0.4, 0.001)
    finally:
      # Clean up, even if an assertion above failed.
      test_op._tf_dispatchers = original_handlers

  def testDispatchForTypes_SignatureMismatch(self):
    with self.assertRaisesRegex(
        AssertionError, "The decorated function's "
        "signature must exactly match.*"):

      @dispatch.dispatch_for_types(test_op, CustomTensor)
      def override_for_test_op(a, b, c):  # pylint: disable=unused-variable
        return CustomTensor(test_op(a.tensor, b.tensor, c.tensor),
                            (a.score + b.score + c.score) / 3.0)

  def testDispatchForTypes_OpDoesNotSupportDispatch(self):
    def some_op(x, y):
      return x + y

    with self.assertRaisesRegex(AssertionError, "Dispatching not enabled for"):

      @dispatch.dispatch_for_types(some_op, CustomTensor)
      def override_for_some_op(x, y):  # pylint: disable=unused-variable
        return x if x.score > 0 else y

  @test.mock.patch.object(tf_logging, "warning", autospec=True)
  def testInteractionWithDeprecationWarning(self, mock_warning):
    @deprecation.deprecated(date=None, instructions="Instructions")
    @dispatch.add_dispatch_support
    def some_op(x):
      return x

    some_op(5)

    # Rebuild the logged message from the mocked call's %-format arguments.
    message = mock_warning.call_args[0][0] % mock_warning.call_args[0][1:]
    self.assertRegex(
        message, r".*some_op \(from __main__\) is deprecated and will be "
        "removed in a future version.*")

  def testGlobalDispatcher(self):
    # Snapshot a copy: register() may append to the list in place, in which
    # case restoring the same list object would not undo the registration.
    original_global_dispatchers = dispatch._GLOBAL_DISPATCHERS[:]
    try:
      TensorTracerOpDispatcher().register()

      x = TensorTracer("x")
      y = TensorTracer("y")
      trace = math_ops.reduce_sum(math_ops.add(math_ops.abs(x), y), axis=3)
      self.assertEqual(
          str(trace),
          "math.reduce_sum(math.add(name=None, x=math.abs(x), y=y), axis=3)")

      proto_val = TensorTracer("proto")
      trace = decode_proto(proto_val, "message_type", ["field"], ["float32"])
      self.assertIn("io.decode_proto(bytes=proto,", str(trace))

    finally:
      # Clean up.
      dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers

  def testGlobalDispatcherConvertToTensor(self):
    # Copy, not alias; see testGlobalDispatcher.
    original_global_dispatchers = dispatch._GLOBAL_DISPATCHERS[:]
    try:
      TensorTracerOpDispatcher().register()

      x = TensorTracer("x")
      y = TensorTracer("y")
      trace = math_ops.add(math_ops.abs(
          ops.convert_to_tensor_v2_with_dispatch(x)), y)
      self.assertEqual(
          str(trace),
          "math.add(name=None, x=math.abs(convert_to_tensor(x)), y=y)")

    finally:
      # Clean up.
      dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers

  def testGlobalDispatcherGetItem(self):
    # Copy, not alias; see testGlobalDispatcher.
    original_global_dispatchers = dispatch._GLOBAL_DISPATCHERS[:]
    try:
      TensorTracerOpDispatcher().register()

      x = TensorTracer("x")
      trace = x[0]
      self.assertEqual(
          str(trace),
          "__operators__.getitem(x, 0)")

      x = TensorTracer("x")
      y = TensorTracer("y")
      trace = x[y]
      self.assertEqual(
          str(trace),
          "__operators__.getitem(x, y)")

      x = TensorTracer("x")
      y = TensorTracer("y")
      trace = x[:y]  # pylint: disable=invalid-slice-index
      self.assertEqual(
          str(trace),
          "__operators__.getitem(x, slice(None, y, None))")

      x = array_ops.ones(shape=(3, 3))
      y = TensorTracer("y")
      trace = x[y]
      self.assertEqual(
          str(trace),
          "__operators__.getitem(%s, y)" % x)

      trace = x[:y]  # pylint: disable=invalid-slice-index
      self.assertEqual(
          str(trace),
          "__operators__.getitem(%s, slice(None, y, None))" % x)

    finally:
      # Clean up.
      dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers

  def testGlobalDispatcherLinearOperators(self):
    # Copy, not alias; see testGlobalDispatcher.
    original_global_dispatchers = dispatch._GLOBAL_DISPATCHERS[:]
    try:
      TensorTracerOpDispatcher().register()

      x = TensorTracer("x")

      # To grab the eigenvalues the diag operator just calls convert_to_tensor
      # (twice) in this case.
      trace = linear_operator_diag.LinearOperatorDiag(x).eigvals()
      self.assertEqual(
          str(trace),
          "convert_to_tensor(convert_to_tensor(x, dtype=None, dtype_hint=None, "
          "name=diag))")

      # The diagonal tensor addition gets traced even though the linear_operator
      # API only uses dispatchable ops instead of directly exposing dispatching.
      trace = linear_operator_diag.LinearOperatorDiag(x).add_to_tensor(x)
      self.assertIn(
          "linalg.set_diag(convert_to_tensor(x, name=x), __operators__.add("
          "convert_to_tensor(x, dtype=None, dtype_hint=None, name=diag), "
          "linalg.diag_part(convert_to_tensor(x, name=x)), "
          "name=",
          str(trace))

      # The dispatch-supporting ops the non-singular check calls out to
      # get traced.
      trace = linear_operator_diag.LinearOperatorDiag(x).assert_non_singular()
      self.assertIn("debugging.assert_less", str(trace))
      self.assertIn(
          "message=Singular operator:  Diagonal contained zero values.",
          str(trace))

    finally:
      # Clean up.
      dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers
329
if __name__ == "__main__":
  # Run this file's tests with googletest (tensorflow.python.platform).
  googletest.main()
332