# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unified callbacks for op execution and creation under eager and graph modes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.eager import execute


def add_op_callback(callback_fn):
  r"""Add a thread-local callback that intercepts op execution and op creation.

  The `callback_fn` will be invoked immediately after any of the three types
  of events:
    - The execution of a TensorFlow operation ("op" for short hereafter)
      under eager mode,
    - The execution of a FuncGraph under eager mode,
    - The creation of an op during graph construction (e.g., in
      @tf.function-decorated Python functions).

  Known limitations:
    1. Under graph mode, overriding the output tensors of control-flow ops,
       including "If", "StatelessIf" and "While", may cause errors
       (b/139668453). Overriding other tensors in a graph consisting of such
       control-flow ops is okay.
    2. Under eager mode, calling eager ops from the callback function itself
       may lead to recursion stack overflow. This can be prevented by
       returning from the callback function immediately on encountering the
       op type involved (b/140334369).

  Args:
    callback_fn: A callback_fn that has the following signature:
      def callback_fn(op_type,
                      inputs,
                      attrs,
                      outputs,
                      op_name=None,
                      graph=None):
        # op_type: The type of the op, as a string. E.g., "MatMul".
        #          For the special case of FuncGraph execution, op_type
        #          takes the name of the graph name, e.g.,
        #          "__inference_my_func_24".
        # inputs: (`tuple` of `Tensor`s) Input tensors to the op or the
        #         FuncGraph.
        #         - In eager execution, these are `EagerTensor`s.
        #         - In graph construction, these are non-eager `Tensor`s
        #           that form the inputs to the just-created op.
        # attrs: The attributes of the op or FuncGraph of which the execution
        #        or creation caused the current invocation of the callback.
        #        This is applicable to both eager- and graph-based execution,
        #        as well as graph construction.
        #        This is a tuple of alternating attribute keys and attribute
        #        values. E.g., `('adjoint_a', False, 'adjoint_b', False)`.
        # outputs: (`tuple` of `Tensor`s) Output tensors from the op or
        #          FuncGraph.
        #          In eager execution, these are `EagerTensor`s.
        #          In graph construction, these are non-eager `Tensor`s that
        #          are the outputs of the just-created op.
        #          When several callbacks are registered, each callback
        #          receives the outputs as possibly overridden by the
        #          callbacks invoked before it.
        # op_name: Name of the op.
        #          - If the current invocation of the callback is due to the
        #            eager execution of an op or FuncGraph, this will be
        #            `None`, as op names are meaningless in eager execution.
        #          - In graph construction, this is the name of the op, e.g.,
        #            "MatMul_2".
        # graph: The graph that the op belongs to (if any).
        #        - In eager execution of an op or FuncGraph, this is `None`.
        #        - In graph construction, this is the op's enclosing graph
        #          as a `tf.Graph` object.
        #
        # Return values:
        #   This callback function is expected to return `None` or
        #   a `list` or `tuple` of `Tensor`s with its length matching
        #   `len(outputs)`, in the order that corresponds to that of the
        #   `outputs` argument.
        #   If the return value is `None`, downstream execution or graph
        #   construction will be unaffected.
        #   However, if the return value is a `list` or `tuple` of `Tensor`s,
        #     - In eager execution, these returned `Tensor`s should be
        #       `EagerTensor`s. Their values will replace the original values of
        #       `outputs` for downstream eager execution. (*Not implemented yet*).
        #     - In graph construction, these returned `Tensor`s should be
        #       non-eager `Tensor`s. Their values will replace the original
        #       `outputs` for downstream graph construction.

  Raises:
    ValueError: If `callback_fn` is `None` or not callable.
  """
  # TODO(b/139668041): Implement support for overriding `EagerTensor`s from
  # callback.
  if callback_fn is None:
    raise ValueError("Passed callback function cannot be None.")
  if not callable(callback_fn):
    raise ValueError(
        "Callback function passed to op_callback() is expected to be callable, "
        "but is not. Received %s" % callback_fn)
  ctx = context.context()
  ctx.add_op_callback(callback_fn)
  if ctx.executing_eagerly():
    # Monkey-patch `execute.execute()` to `execute.execute_with_callbacks`
    # while callbacks are registered; `remove_op_callback()` undoes this.
    execute.execute = execute.execute_with_callbacks


def should_invoke_op_callbacks():
  """Determine if op callbacks are present and should be invoked.

  Returns:
    A thread-local result (boolean) indicating whether any op callback(s) exist
    and should be invoked. This is `False` while callbacks are already being
    invoked (see `invoke_op_callbacks()`), which prevents recursive invocation.
  """
  ctx = context.context()
  # `bool()` so that callers get a real boolean rather than the callbacks list.
  return bool(ctx.op_callbacks) and not ctx.invoking_op_callbacks


def remove_op_callback(op_callback):
  """Remove an already-added op callback.

  Args:
    op_callback: The op callback to be removed.

  Raises:
    KeyError: If `op_callback` has not been registered using `add_op_callback()`
      before.
  """
  ctx = context.context()
  ctx.remove_op_callback(op_callback)
  if ctx.executing_eagerly() and not ctx.op_callbacks:
    # Undo monkey-patch of execute.execute if there are no more callbacks.
    execute.execute = execute.quick_execute


def clear_op_callbacks():
  """Clear all op callbacks registered in the current thread."""
  # Iterate over a snapshot: `remove_op_callback()` removes entries from the
  # registered callbacks, and removing from a list while iterating over it
  # skips elements, which would leave some callbacks registered.
  for callback in list(context.context().op_callbacks):
    remove_op_callback(callback)


def invoke_op_callbacks(op_type,
                        inputs,
                        attrs,
                        outputs,
                        op_name=None,
                        graph=None):
  r"""Invoke the callbacks that exist in the current scope (if any).

  If no callbacks are present in the current scope, this method returns
  `outputs` immediately.

  Callbacks are invoked in the order given by the context's `op_callbacks`
  and are chained: each callback receives the outputs as overridden by the
  callbacks invoked before it. A callback that returns `None` leaves the
  outputs unchanged for the callbacks that follow it.

  Args:
    op_type: Type of the operation (e.g., "MatMul").
    inputs: Input tensors to the op. These are `EagerTensor`s in the case of
      eager execution of ops or `FuncGraph`s, and are non-eager `Tensor`s in the
      case of graph construction.
    attrs: Attributes of the op, as `tuple` of alternating keys and values, or
      as a `dict` (which is flattened into such a tuple before being passed to
      the callbacks).
    outputs: Output tensors from the op. These are `EagerTensor`s in the case of
      eager execution and are non-eager `Tensor`s in the case of graph
      construction.
    op_name: Name of the op. Applicable if and only if this method is invoked
      due to the graph construction of an op or the eager execution of a
      `FuncGraph`.
    graph: The graph involved (if any).
      - In the case of the eager execution of an op or FuncGraph, this is
        `None`.
      - In the case of the graph construction of an op, this is the `tf.Graph`
        object being built.

  Returns:
    - `outputs` itself, if no callbacks are registered.
    - `None`, if callbacks are registered but none of them overrode the
      outputs.
    - Otherwise, the `list` or `tuple` of output tensors returned by the last
      callback that overrode them; these should replace the original
      `outputs`.

  Raises:
    ValueError: If a callback returns a number of tensors that differs from
      `len(outputs)`.
  """
  ctx = context.context()
  if not ctx.op_callbacks:
    return outputs

  # Guards against stack overflow that can result from recursive invocation
  # due to op constructions inside client-supplied op callbacks.
  ctx.invoking_op_callbacks = True
  try:
    if isinstance(attrs, dict):
      # Flatten into alternating keys and values, e.g. ('a', 1, 'b', 2).
      attrs_tuple = tuple(
          item for key in attrs for item in (key, attrs[key]))
    else:
      attrs_tuple = attrs

    current_outputs = outputs
    overridden = False
    for callback in ctx.op_callbacks:
      callback_outputs = callback(
          op_type,
          inputs,
          attrs_tuple,
          current_outputs,
          op_name=op_name,
          graph=graph)
      if callback_outputs is None:
        # `None` means "no override": keep passing the current outputs along
        # rather than handing `None` to the next callback.
        continue
      if len(callback_outputs) != len(outputs):
        raise ValueError(
            "The op callback returned %s tensors, which does not match the "
            "original number of outputs of op %s (%d)." %
            (len(callback_outputs),
             # `op_name` is `None` in eager execution; fall back to the type.
             op_name if op_name is not None else op_type,
             len(outputs)))
      current_outputs = callback_outputs
      overridden = True
    return current_outputs if overridden else None
  finally:
    ctx.invoking_op_callbacks = False