# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script Language Operators."""

# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

# Used by py_util.cc to get tracebacks.
import traceback  # pylint: disable=unused-import
import weakref

import numpy as np
import six

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_script_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

# Map from EagerPyFunc token to tuple (tape, eager args, eager outputs);
# used for differentiation.
tape_cache = {}
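# For example (an illustrative note; the token value is hypothetical): after
# the EagerPyFunc with token b"pyfunc_0" executes, the gradient function
# registered in _EagerPyFuncGrad below pops tape_cache[b"pyfunc_0"] and calls
# tape.gradient() on the recorded eager inputs and outputs.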


class EagerFunc(object):
  """A wrapper for a function owned by an EagerPyFunc."""

  def __init__(self, func, Tout, is_grad_func):
    """Constructs an EagerFunc.

    Args:
      func: The function to wrap.
      Tout: A list of datatypes for the output; an empty list if the output is
            None.
      is_grad_func: Whether this EagerFunc is the gradient of another
        EagerPyFunc.
    """
    self._func = func
    self._out_dtypes = Tout
    self._is_grad_func = is_grad_func

  def _convert(self, value, dtype):
    """Converts `value` to a tensor of type `dtype`, with error checking.

    Args:
      value: The tensor to convert.
      dtype: The desired dtype.

    Returns:
      A tensor of type `dtype`, or a zeros tensor if value is None and
      this function is in fact a gradient function.

    Raises:
      RuntimeError: if `value` is a variable.
    """

    if isinstance(value, resource_variable_ops.ResourceVariable):
      raise RuntimeError(
          "Attempting to return a variable from an eagerly executed py_func. "
          "Only numeric data structures like Tensors or NumPy arrays should "
          "be returned; to return the value of a variable, make sure to obtain "
          "the Tensor backing it by calling `.read_value()` on the variable in "
          "question: %s" % value)
    if value is None and self._is_grad_func:
      # Gradient functions may legitimately return a list that contains
      # both Tensors and Python Nones. Unfortunately this breaks the
      # OpKernel, so for now we replace None objects with zeros, which is
      # mathematically correct but will prevent short-circuiting gradient
      # computations.
      #
      # TODO(akshayka): Make it possible to return a list of both Tensors and
      # Nones from an EagerPyFunc.
      return constant_op.constant(0.0, dtype=dtype)
    return ops.convert_to_tensor(value, dtype=dtype)

  def __call__(self, device, token, args):
    """Passes `args` to `self._func`, which is executed eagerly."""

    with context.eager_mode(), backprop.GradientTape() as tape:
      for tensor in args:
        tape.watch(tensor)
      ret = self._func(*args)
      # Use tf.identity to copy the returned tensors to device if necessary.
      with ops.device(device):
        if isinstance(ret, (tuple, list)):
          outputs = [
              array_ops.identity(self._convert(x, dtype=dtype))
              for (x, dtype) in zip(ret, self._out_dtypes)
          ]
        elif ret is None:
          outputs = None
        else:
          outputs = array_ops.identity(
              self._convert(ret, dtype=self._out_dtypes[0]))
    tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
    return outputs


class FuncRegistry(object):
  """A helper class to keep track of registered py functions.

  FuncRegistry keeps a map from unique tokens (string) to python
  functions, which take numpy arrays and output numpy arrays.
  """

  def __init__(self):
    self._lock = threading.Lock()
    self._unique_id = 0  # GUARDED_BY(self._lock)
    # Only store weakrefs to the functions. The strong reference is stored in
    # the graph.
    self._funcs = weakref.WeakValueDictionary()

  def insert(self, func):
    """Registers `func` and returns a unique token for this entry."""
    token = self._next_unique_token()
    # Store a weakref to the function
    self._funcs[token] = func
    return token

  def remove(self, token):
    """Removes the registered function corresponding to `token`."""
    self._funcs.pop(token, None)

  @staticmethod
  def _convert(value, dtype=None):
    """Converts an arg to numpy, avoiding dangerous string and unicode dtypes.

    Numpy pads with zeros when using string and unicode dtypes if different
    components of a tensor have different lengths.  This is bad: ignoring the
    padding is wrong for text data, and removing the padding is wrong for binary
    data.  To avoid this bug, we redo the conversion using an object dtype.
    Additionally, we convert unicode strings to (byte-)strings for
    compatibility.

    Args:
      value: Value to convert to a numpy array.
      dtype: (Optional.) Desired NumPy type for the returned value.

    Returns:
      A numpy array.
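
    For example (an illustrative sketch; exact reprs vary across NumPy
    versions):

      np.asarray([b"a", b"bb"])   # Fixed-width dtype "S2": b"a" is padded.
      _convert([b"a", b"bb"])     # Object dtype: original lengths preserved.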
    """
    result = np.asarray(value, dtype=dtype, order="C")
    if result.dtype.char == "S" and result is not value:
      return np.asarray(value, order="C", dtype=object)
    elif result.dtype.char == "U" and result is not value:
      value = np.vectorize(lambda x: x.encode("utf8"))(value)
      return np.asarray(value, order="C", dtype=object)
    elif result.dtype.char == "U":
      return result.astype(np.bytes_)
    else:
      return result

  def __call__(self, token, device, args):
    """Calls the registered function for `token` with args.

    Args:
      token: A key into this `FuncRegistry` identifying which function to call.
      device: Name of the device on which outputs of `token`'s corresponding
        operation should be placed. Used iff the function registered for `token`
        is an EagerPyFunc.
      args: The arguments to pass to the function registered for `token`.

    Returns:
      The output of the function registered for `token`.

    Raises:
      ValueError: if no function is registered for `token`.
    """
    func = self._funcs.get(token, None)
    if func is None:
      raise ValueError("callback %s is not found" % token)
    if isinstance(func, EagerFunc):
      # NB: Different invocations of the same py_func will share the same
      # token, and the entries they stash in the tape_cache will collide.
      # In practice, when executing a graph, this should only happen if
      # the py_func is in a while_loop whose iterations are run in parallel
      # or if the graph is being driven by concurrent session.run() calls.
      #
      # TODO(akshayka): Key the tape cache in a thread-safe way.
      return func(device, token, args)
    else:
      ret = func(*args)
      # Strings seem to lead to a memory leak here if they're not wrapped in a
      # list.
      if isinstance(ret, six.binary_type):
        ret = [ret]
      # Ensures that we return either a single numpy array or a list of numpy
      # arrays.
      if isinstance(ret, (tuple, list)):
        return [self._convert(x) for x in ret]
      else:
        return self._convert(ret)

  def size(self):
    """Returns how many functions are currently registered."""
    return len(self._funcs)

  def _next_unique_token(self):
    """Returns a unique token."""
    with self._lock:
      uid = self._unique_id
      self._unique_id += 1
    return "pyfunc_%d" % uid

# Global registry for py functions.
_py_funcs = FuncRegistry()

pywrap_tensorflow.InitializePyTrampoline(_py_funcs)


def _internal_py_func(func,
                      inp,
                      Tout,
                      stateful=None,
                      eager=False,
                      is_grad_func=False,
                      name=None):
  """See documentation for py_func and eager_py_func."""

  is_list_or_tuple = False
  if isinstance(Tout, (list, tuple)):
    is_list_or_tuple = True
  else:
    Tout = [Tout]

  if eager:
    func = EagerFunc(func, Tout, is_grad_func)

  token = _py_funcs.insert(func)
  # We tie the registered function's lifetime with the current default graph,
  # i.e., when the current graph is destroyed, we remove its py funcs.
  graph = ops.get_default_graph()

  # pylint: disable=protected-access
  while isinstance(graph, function._FuncGraph):
    # If the py_func was declared inside a _FuncGraph, its lifetime should be
    # bound to that of the outer graph instead.
    graph = graph._outer_graph

  # TODO(zhifengc): Consider adding a Graph method to collect
  # `cleanup` objects in one of its members.
  if not hasattr(graph, "_py_funcs_used_in_graph"):
    graph._py_funcs_used_in_graph = []

  # Store a reference to the function in the graph to ensure it stays alive
  # as long as the graph lives. When the graph is destroyed, the function
  # is left to the garbage collector for destruction as well.
  graph._py_funcs_used_in_graph.append(func)
  # pylint: enable=protected-access

  if eager:
    result = gen_script_ops.eager_py_func(
        input=inp, token=token, Tout=Tout, name=name)
  else:
    if stateful:
      result = gen_script_ops.py_func(
          input=inp, token=token, Tout=Tout, name=name)
    else:
      result = gen_script_ops.py_func_stateless(
          input=inp, token=token, Tout=Tout, name=name)
  return result if is_list_or_tuple else result[0]


# TODO(akshayka): Implement higher-order derivatives.
@ops.RegisterGradient("EagerPyFunc")
def _EagerPyFuncGrad(op, *dy):
  """Computes the gradient of an EagerPyFunc."""

  token = op.get_attr("token")

  def eagerly_executed_grad(*dy):
    tape, eager_inputs, eager_outputs = tape_cache.pop(compat.as_bytes(token))
    return tape.gradient(eager_outputs, eager_inputs, output_gradients=dy)

  with ops.control_dependencies(op.outputs):
    return _internal_py_func(
        func=eagerly_executed_grad,
        inp=dy,
        Tout=[tensor.dtype for tensor in op.inputs],
        eager=True,
        is_grad_func=True)


@tf_export("py_function")
def eager_py_func(func, inp, Tout, name=None):
  """Wraps a python function into a TensorFlow op that executes it eagerly.

  This function allows expressing computations in a TensorFlow graph as
  Python functions. In particular, it wraps a Python function `func`
  in a once-differentiable TensorFlow operation that executes it with eager
  execution enabled. As a consequence, `tf.contrib.eager.py_func` makes it
  possible to express control flow using Python constructs (`if`, `while`,
  `for`, etc.), instead of TensorFlow control flow constructs (`tf.cond`,
  `tf.while_loop`). For example, you might use `tf.contrib.eager.py_func` to
  implement the log huber function:

  ```python
  def log_huber(x, m):
    if tf.abs(x) <= m:
      return x**2
    else:
      return m**2 * (1 - 2 * tf.log(m) + tf.log(x**2))

  x = tf.placeholder(tf.float32)
  m = tf.placeholder(tf.float32)

  y = tf.contrib.eager.py_func(func=log_huber, inp=[x, m], Tout=tf.float32)
  dy_dx = tf.gradients(y, x)[0]

  with tf.Session() as sess:
    # The session executes `log_huber` eagerly. Given the feed values below,
    # it will take the first branch, so `y` evaluates to 1.0 and
    # `dy_dx` evaluates to 2.0.
    y, dy_dx = sess.run([y, dy_dx], feed_dict={x: 1.0, m: 2.0})
  ```

  You can also use `tf.contrib.eager.py_func` to debug your models at runtime
  using Python tools; that is, you can isolate the portions of your code you
  want to debug, wrap them in Python functions, insert `pdb` tracepoints or
  print statements as desired, and wrap those functions in
  `tf.contrib.eager.py_func`.
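
  For example (an illustrative sketch; `debug_fn` is a hypothetical name):

  ```python
  def debug_fn(x):
    import pdb; pdb.set_trace()  # Inspect `x` as an eager Tensor.
    return x

  x = tf.placeholder(tf.float32)
  y = tf.contrib.eager.py_func(func=debug_fn, inp=[x], Tout=tf.float32)
  ```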

  For more information on eager execution, see the
  [Eager guide](https://tensorflow.org/guide/eager).

  `tf.contrib.eager.py_func` is similar in spirit to `tf.py_func`, but unlike
  the latter, the former lets you use TensorFlow operations in the wrapped
  Python function. In particular, while `tf.py_func` only runs on CPUs and
  wraps functions that take NumPy arrays as inputs and return NumPy arrays as
  outputs, `tf.contrib.eager.py_func` can be placed on GPUs and wraps functions
  that take Tensors as inputs, execute TensorFlow operations in their bodies,
  and return Tensors as outputs.

  Like `tf.py_func`, `tf.contrib.eager.py_func` has the following limitations
  with respect to serialization and distribution:

  * The body of the function (i.e. `func`) will not be serialized in a
    `GraphDef`. Therefore, you should not use this function if you need to
    serialize your model and restore it in a different environment.

  * The operation must run in the same address space as the Python program
    that calls `tf.contrib.eager.py_func()`. If you are using distributed
    TensorFlow, you must run a `tf.train.Server` in the same process as the
    program that calls `tf.contrib.eager.py_func()` and you must pin the created
    operation to a device in that server (e.g. using `with tf.device():`).


  Args:
    func: A Python function which accepts a list of `Tensor` objects
      having element types that match the corresponding `tf.Tensor` objects
      in `inp` and returns a list of `Tensor` objects (or a single
      `Tensor`, or `None`) having element types that match the
      corresponding values in `Tout`.
    inp: A list of `Tensor` objects.
    Tout: A list or tuple of tensorflow data types or a single tensorflow data
      type if there is only one, indicating what `func` returns; an empty list
      if no value is returned (i.e., if the return value is `None`).
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` or a single `Tensor` which `func` computes; an empty list
    if `func` returns None.
  """
  return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name)


@deprecation.deprecated(
    date=None,
    instructions="""tf.py_func is deprecated in TF V2. Instead, there are two
    options available in V2.
    - tf.py_function takes a python function which manipulates tf eager
    tensors instead of numpy arrays. It's easy to convert a tf eager tensor to
    an ndarray (just call tensor.numpy()) but having access to eager tensors
    means `tf.py_function`s can use accelerators such as GPUs as well as
    being differentiable using a gradient tape.
    - tf.numpy_function maintains the semantics of the deprecated tf.py_func
    (it is not differentiable, and manipulates numpy arrays). It drops the
    stateful argument making all functions stateful.
    """)
@tf_export(v1=["py_func"])
def py_func(func, inp, Tout, stateful=True, name=None):
  """Wraps a python function and uses it as a TensorFlow op.

  Given a python function `func`, which takes numpy arrays as its
  arguments and returns numpy arrays as its outputs, wrap this function as an
  operation in a TensorFlow graph. The following snippet constructs a simple
  TensorFlow graph that invokes the `np.sinh()` NumPy function as an
  operation in the graph:

  ```python
  def my_func(x):
    # x will be a numpy array with the contents of the placeholder below
    return np.sinh(x)
  input = tf.placeholder(tf.float32)
  y = tf.py_func(my_func, [input], tf.float32)
  ```
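
  The resulting `y` can then be evaluated like any other graph output, for
  example (an illustrative continuation of the snippet above):

  ```python
  with tf.Session() as sess:
    print(sess.run(y, feed_dict={input: np.array(1.0, np.float32)}))
  ```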

  **N.B.** The `tf.py_func()` operation has the following known limitations:

  * The body of the function (i.e. `func`) will not be serialized in a
    `GraphDef`. Therefore, you should not use this function if you need to
    serialize your model and restore it in a different environment.

  * The operation must run in the same address space as the Python program
    that calls `tf.py_func()`. If you are using distributed TensorFlow, you
    must run a `tf.train.Server` in the same process as the program that calls
    `tf.py_func()` and you must pin the created operation to a device in that
    server (e.g. using `with tf.device():`).

  Args:
    func: A Python function, which accepts `ndarray` objects as arguments and
      returns a list of `ndarray` objects (or a single `ndarray`). This function
      must accept as many arguments as there are tensors in `inp`, and these
      argument types will match the corresponding `tf.Tensor` objects
      in `inp`. The returned `ndarray`s must match the number and types
      defined in `Tout`.
      Important Note: Input and output numpy `ndarray`s of `func` are not
      guaranteed to be copies. In some cases their underlying memory will be
      shared with the corresponding TensorFlow tensors. In-place modification
      or storing `func` input or return values in python data structures
      without an explicit (np.)copy can have non-deterministic consequences.
    inp: A list of `Tensor` objects.
    Tout: A list or tuple of tensorflow data types or a single tensorflow data
      type if there is only one, indicating what `func` returns.
    stateful: (Boolean.) If True, the function should be considered stateful.
      If a function is stateless, when given the same input it will return the
      same output and have no observable side effects. Optimizations such as
      common subexpression elimination are only performed on stateless
      operations.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` or a single `Tensor` which `func` computes.
  """
  if context.executing_eagerly():
    result = func(*[x.numpy() for x in inp])
    result = nest.flatten(result)

    result = [x if x is None else ops.convert_to_tensor(x) for x in result]
    if len(result) == 1:
      # Mimic the automatic unwrapping in graph-mode py_func
      result, = result
    return result

  return _internal_py_func(
      func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name)


@tf_export("numpy_function", v1=[])
def numpy_function(func, inp, Tout, name=None):
  return py_func(func, inp, Tout, stateful=True, name=name)

numpy_function.__doc__ = py_func.__doc__.replace(
    "py_func", "numpy_function")
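
# For example (an illustrative sketch; `my_numpy_func` is a hypothetical
# name):
#
#   def my_numpy_func(x):
#     # `x` arrives as a numpy ndarray rather than a Tensor.
#     return np.sinh(x)
#
#   y = tf.numpy_function(my_numpy_func, [tf.constant(1.0)], tf.float32)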


ops.NotDifferentiable("PyFunc")
ops.NotDifferentiable("PyFuncStateless")