# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Unified callbacks op execution and creation under eager and graph modes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.eager import execute


def add_op_callback(callback_fn):
  r"""Add a thread-local callback that intercepts op execution and op creation.

  The `callback_fn` will be invoked immediately after any of the three types
  of events:
    - The execution of a TensorFlow operation ("op" for short hereafter)
      under eager mode,
    - The execution of a FuncGraph under eager mode,
    - The creation of an op during graph construction (e.g., in
      @tf.function-decorated Python functions).

  Known limitations:
    1. Under graph mode, overriding the output tensors of control-flow ops,
       including "If", "StatelessIf" and "While", may cause errors
       (b/139668453). Overriding other tensors in a graph consisting of such
       control-flow ops is okay.
    2. Under eager mode, calling eager ops from the callback function itself
       may lead to recursive invocation and hence stack overflow. This can be
       prevented by returning from the callback immediately when it encounters
       the op type involved (b/140334369).

  Args:
    callback_fn: A callable with the following signature:
      def callback_fn(op_type,
                      inputs,
                      attrs,
                      outputs,
                      op_name=None,
                      graph=None):
        # op_type: The type of the op, as a string. E.g., "MatMul".
        #          For the special case of FuncGraph execution, op_type
        #          takes the name of the graph, e.g.,
        #          "__inference_my_func_24".
        # inputs: (`tuple` of `Tensor`s) Input tensors to the op or the
        #         FuncGraph.
        #         - In eager execution, these are `EagerTensor`s.
        #         - In graph construction, these are non-eager `Tensor`s
        #           that form the inputs to the just-created op.
        # attrs: The attributes of the op or FuncGraph of which the execution
        #        or creation caused the current invocation of the callback.
        #        This is applicable to both eager- and graph-based execution,
        #        as well as graph construction.
        #        This is a tuple of alternating attribute keys and attribute
        #        values. E.g., `('adjoint_a', False, 'adjoint_b', False)`.
        # outputs: (`tuple` of `Tensor`s) Output tensors from the op or
        #          FuncGraph.
        #          In eager execution, these are `EagerTensor`s.
        #          In graph construction, these are non-eager `Tensor`s that
        #          are the outputs of the just-created op.
        # op_name: Name of the op.
        #          - If the current invocation of the callback is due to the
        #            eager execution of an op or FuncGraph, this will be
        #            `None`, as op names are meaningless in eager execution.
        #          - In graph construction, this is the name of the op, e.g.,
        #            "MatMul_2".
        # graph: The graph that the op belongs to (if any).
        #        - In eager execution of an op or FuncGraph, this is `None`.
        #        - In graph construction, this is the op's enclosing graph
        #          as a `tf.Graph` object.
        #
        # Return values:
        #   This callback function is expected to return `None` or
        #   a `list` or `tuple` of `Tensor`s with its length matching
        #   `len(outputs)`, in the order that corresponds to that of the
        #   `outputs` argument.
        #   If the return value is `None`, downstream execution or graph
        #   construction will be unaffected.
        #   However, if the return value is a `list` or `tuple` of `Tensor`s,
        #   - In eager execution, these returned `Tensor`s should be
        #     `EagerTensor`s. Their values will replace the original values of
        #     `outputs` for downstream eager execution. (*Not implemented yet*).
        #   - In graph construction, these returned `Tensor`s should be
        #     non-eager `Tensor`s. Their values will replace the original
        #     `outputs` for downstream graph construction.

  Raises:
    ValueError: If `callback_fn` is `None` or not callable.
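
  Example: a minimal sketch of a callback that logs op types and leaves all
  outputs unchanged (`log_op_callback` is an illustrative name, not part of
  this module):

    def log_op_callback(op_type, inputs, attrs, outputs,
                        op_name=None, graph=None):
      # Returning `None` leaves downstream execution or graph construction
      # unaffected.
      print("Intercepted %s with %d input tensor(s)." % (op_type, len(inputs)))
      return None

    add_op_callback(log_op_callback)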
  """
  # TODO(b/139668041): Implement support for overriding `EagerTensor`s from
  # callback.
  if callback_fn is None:
    raise ValueError("Passed callback function cannot be None.")
  if not callable(callback_fn):
    raise ValueError(
        "Callback function passed to add_op_callback() is expected to be "
        "callable, but is not. Received %s" % callback_fn)
  ctx = context.context()
  ctx.add_op_callback(callback_fn)
  if ctx.executing_eagerly():
    # Monkey-patch `execute.execute()` so that eager op execution invokes the
    # registered callbacks.
    execute.execute = execute.execute_with_callbacks


def should_invoke_op_callbacks():
  """Determine if op callbacks are present and should be invoked.

  Returns:
    A thread-local result (boolean) indicating whether any op callback(s) exist
    and should be invoked.
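
  Example: an illustrative guard pattern (`op_type`, `inputs`, `attrs` and
  `outputs` are placeholders, not variables defined in this module):

    if should_invoke_op_callbacks():
      outputs = invoke_op_callbacks(op_type, inputs, attrs, outputs)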
  """
  ctx = context.context()
  return ctx.op_callbacks and not ctx.invoking_op_callbacks


def remove_op_callback(op_callback):
  """Remove an already-added op callback.

  Args:
    op_callback: The op callback to be removed.

  Raises:
    KeyError: If `op_callback` has not been registered using `add_op_callback()`
      before.
  """
  ctx = context.context()
  ctx.remove_op_callback(op_callback)
  if ctx.executing_eagerly() and not ctx.op_callbacks:
    # Undo monkey-patch of execute.execute if there are no more callbacks.
    execute.execute = execute.quick_execute


def clear_op_callbacks():
  """Clear all op callbacks registered in the current thread."""
  for callback in context.context().op_callbacks:
    remove_op_callback(callback)


def invoke_op_callbacks(op_type,
                        inputs,
                        attrs,
                        outputs,
                        op_name=None,
                        graph=None):
  r"""Invoke the callbacks that exist in the current scope (if any).

  If no callbacks are present in the current scope, this method returns
  immediately.

  Args:
    op_type: Type of the operation (e.g., "MatMul").
    inputs: Input tensors to the op. These are `EagerTensor`s in the case of
      eager execution of ops or `FuncGraph`s, and are non-eager `Tensor`s in the
      case of graph construction.
    attrs: Attributes of the op, as a `tuple` of alternating keys and values.
    outputs: Output tensors from the op. These are `EagerTensor`s in the case of
      eager execution and are non-eager `Tensor`s in the case of graph
      construction.
    op_name: Name of the op. Applicable if and only if this method is invoked
      due to the graph construction of an op or the eager execution of a
      `FuncGraph`.
    graph: The graph involved (if any).
      - In the case of the eager execution of an op or FuncGraph, this is
        `None`.
      - In the case of the graph construction of an op, this is the `tf.Graph`
        object being built.

  Returns:
    `None`, or a `list` or `tuple` of output tensors that will override the
    original (input) `outputs`.
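
  Example: if `attrs` is passed in as a `dict`, it is converted to the
  alternating keys-and-values `tuple` before the callbacks are invoked, so a
  call such as the following (names illustrative):

    invoke_op_callbacks(
        "MatMul", inputs, {"adjoint_a": False}, outputs, op_name="MatMul_2")

  passes `('adjoint_a', False)` as the `attrs` argument to each registered
  callback.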
  """
  ctx = context.context()
  if ctx.op_callbacks:
    # Guards against stack overflow that can result from recursive invocation
    # due to op constructions inside client-supplied op callbacks.
    ctx.invoking_op_callbacks = True
    try:
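      # `attrs` may arrive as a `dict` (e.g., during graph construction).
      # Normalize it to the tuple-of-alternating-keys-and-values layout that
      # the documented callback signature expects.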
      if isinstance(attrs, dict):
        attrs_list = []
        for key in attrs:
          attrs_list.append(key)
          attrs_list.append(attrs[key])
        attrs_tuple = tuple(attrs_list)
      else:
        attrs_tuple = attrs

      new_outputs = outputs
      for callback in ctx.op_callbacks:
        new_outputs = callback(
            op_type,
            inputs,
            attrs_tuple,
            new_outputs,
            op_name=op_name,
            graph=graph)
        if new_outputs is not None and len(new_outputs) != len(outputs):
          raise ValueError(
              "The op callback returned %s tensors, which does not match the "
              "original number of outputs of op %s (%d)." %
              (len(new_outputs), op_name, len(outputs)))
      return new_outputs
    finally:
      ctx.invoking_op_callbacks = False
  else:
    return outputs