# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script Language Operators."""

# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

# Used by py_util.cc to get tracebacks.
import traceback  # pylint: disable=unused-import
import weakref

import numpy as np
import six

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.lib.core import _pywrap_py_func
from tensorflow.python.ops import gen_script_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export

autograph = lazy_loader.LazyLoader(
    "autograph", globals(),
    "tensorflow.python.autograph.impl.api")


# Map from EagerPyFunc token to tuple (tape, eager args, eager outputs);
# used for differentiation.
tape_cache = {}
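# An entry looks roughly like the following (illustrative only):
#   tape_cache[b"pyfunc_0"] = (tape, eager_args, eager_outputs)
# where `tape` recorded the eager execution of the wrapped function so that
# _EagerPyFuncGrad can later compute gradients from it.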


def _maybe_copy_to_context_device(tensor, device_name):
  """Copy an EagerTensor to the current device if it's not on `device_name`."""
  in_device = tensor.backing_device
  if device_name == in_device:
    return tensor
  else:
    # Note that EagerTensor._copy bypasses the placer and copies to the context
    # device, which means e.g. int32 Tensors which would normally be forced onto
    # the CPU can instead be placed on the GPU. This is necessary so that the
    # PyFunc kernel always returns Tensors on the device it's executing on.
    return tensor._copy()  # pylint: disable=protected-access


class EagerFunc(object):
  """A wrapper for a function owned by an EagerPyFunc."""

  def __init__(self, func, Tout, is_grad_func, use_tape_cache=True):
    """Constructs an EagerFunc.

    Args:
      func: The function to wrap.
      Tout: A list of datatypes for the output; an empty list if the output is
        None.
      is_grad_func: Whether this EagerFunc is the gradient of another
        EagerPyFunc.
      use_tape_cache: (Optional.) Whether to cache `func` in the `tape_cache`.
        For additional information, see the description of `_eager_py_func`.
        This parameter should be removed once issue #35084 is fixed.
86    """
87    self._func = func
88    self._out_dtypes = Tout
89    self._is_grad_func = is_grad_func
90    self._use_tape_cache = use_tape_cache
91
92  def _convert(self, value, dtype):
93    """Converts `value` to a tensor of type `dtype`, with error checking.
94
95    Args:
96      value: The tensor to convert.
97      dtype: The desired dtype.
98
99    Returns:
100      A tensor of type `dtype`, or a zeros tensor if value is None and
101      this function is in fact a gradient function.
102
103    Raises:
104      RuntimeError: if `value` is a variable.
105    """
106
107    if isinstance(value, resource_variable_ops.ResourceVariable):
108      raise RuntimeError(
109          "Attempting to return a variable from an eagerly executed py_func. "
110          "Only numeric data structures like Tensors or NumPy arrays should "
111          "be returned; to return the value of a variable, make sure to obtain "
112          "the Tensor backing it by calling `.read_value()` on the variable in "
113          "question: %s" % value)
114    if value is None and self._is_grad_func:
115      # Gradient functions may legitimately return a list that contains
116      # both Tensors and Python Nones. Unfortunately this breaks the
117      # OpKernel, so for now we replace None objects with zeros, which is
118      # mathematically correct but will prevent short-circuiting gradient
119      # computations.
120      #
121      # TODO(akshayka): Make it possible to return a list of both Tensors and
122      # Nones from an EagerPyFunc.
123      return constant_op.constant(0.0, dtype=dtype)
124    return ops.convert_to_tensor(value, dtype=dtype)
125
126  def __call__(self, device, token, args):
127    """Passes `args` to `self._func`, which is executed eagerly."""
128
129    with context.eager_mode(), backprop.GradientTape() as tape:
130      # Only watch tensors with a floating or complex dtype.
131      for tensor in args:
132        for t in nest.flatten(tensor):
133          if t.dtype.is_floating or t.dtype.is_complex:
134            tape.watch(t)
135      ret = self._func(*args)
      # Copy the returned tensors to the PyFunc op's device if necessary.
      device_name = device
      if device_name is None:
        # "None" here means "CPU", from the nullptr convention with C++ device
        # pointers.
        device_name = "/job:localhost/replica:0/task:0/device:CPU:0"
      with ops.device(device):
        if isinstance(ret, (tuple, list)):
          outputs = [
              _maybe_copy_to_context_device(self._convert(x, dtype=dtype),
                                            device_name)
              for (x, dtype) in zip(ret, self._out_dtypes)
          ]
        elif ret is None:
          outputs = None
        else:
          outputs = _maybe_copy_to_context_device(
              self._convert(ret, dtype=self._out_dtypes[0]), device_name)
    if self._use_tape_cache:
      tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
    return outputs


class FuncRegistry(object):
  """A helper class to keep track of registered py functions.

  FuncRegistry keeps a map from unique tokens (strings) to Python functions,
  which take numpy arrays as inputs and produce numpy arrays as outputs.
164  """
165
166  def __init__(self):
167    self._lock = threading.Lock()
168    self._unique_id = 0  # GUARDED_BY(self._lock)
169    # Only store weakrefs to the functions. The strong reference is stored in
170    # the graph.
171    self._funcs = weakref.WeakValueDictionary()
172
173  @property
174  def _ctx(self):
175    # N.B. This is needed to support calling py_func with GPU tensors,
176    # which must be transferred to CPU if used in any of the NumPy APIs.
177    context.ensure_initialized()
178    return context.context()._handle  # pylint: disable=protected-access
179
180  def insert(self, func):
181    """Registers `func` and returns a unique token for this entry."""
182    token = self._next_unique_token()
183    # Store a weakref to the function
184    self._funcs[token] = func
185    return token
186
187  def remove(self, token):
188    """Removes the registered function corresponding to `token`."""
189    self._funcs.pop(token, None)
190
191  @staticmethod
192  def _convert(value, dtype=None):
193    """Converts an arg to numpy, avoiding dangerous string and unicode dtypes.
194
195    Numpy pads with zeros when using string and unicode dtypes if different
196    components of a tensor have different lengths.  This is bad: ignoring the
197    padding is wrong for text data, and removing the padding is wrong for binary
198    data.  To avoid this bug, we redo the conversion using an object dtype.
199    Additionally, we convert unicode strings to (byte-)strings for
200    compatibility.
201
202    Args:
203      value: Value to convert to a numpy array.
204      dtype: (Optional.) Desired NumPy type for the returned value.
205
206    Returns:
207      A numpy array.
208    """
209    result = np.asarray(value, dtype=dtype, order="C")
210    if result.dtype.char == "S" and result is not value:
211      return np.asarray(value, order="C", dtype=object)
212    elif result.dtype.char == "U" and result is not value:
213      value = np.vectorize(lambda x: x.encode("utf8"))(value)
214      return np.asarray(value, order="C", dtype=object)
215    elif result.dtype.char == "U":
216      return result.astype(np.bytes_)
217    else:
218      return result
219
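  # Illustrative behavior of `_convert` (a sketch, not part of the API):
  #   _convert([b"a", b"bc"]) -> object-dtype array [b"a", b"bc"]; a fixed
  #     width "S2" array would zero-pad b"a".
  #   _convert([u"a", u"bc"]) -> object-dtype array of utf-8 encoded byte
  #     strings [b"a", b"bc"].
  #   _convert([1.0, 2.0])    -> a plain float64 array, returned unchanged.
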
  def __call__(self, token, device, args):
    """Calls the registered function for `token` with args.

    Args:
      token: A key into this `FuncRegistry` identifying which function to call.
      device: Name of the device on which outputs of `token`'s corresponding
        operation should be placed. Used iff the function registered for `token`
        is an `EagerFunc`.
      args: The arguments to pass to the function registered for `token`.

    Returns:
      The output of the function registered for `token`.

    Raises:
      ValueError: if no function is registered for `token`.
    """
    func = self._funcs.get(token, None)
    if func is None:
      raise ValueError("callback %s is not found" % token)
    if isinstance(func, EagerFunc):
      # NB: Different invocations of the same py_func will share the same
      # token, and the entries they stash in the tape_cache will collide.
      # In practice, when executing a graph, this should only happen if
      # the py_func is in a while_loop whose iterations are run in parallel
      # or if the graph is being driven by concurrent session.run() calls.
      #
      # TODO(akshayka): Key the tape cache in a thread-safe way.
      return func(device, token, args)
    else:
      ret = func(*args)
      # Strings seem to lead to a memory leak here if they're not wrapped in a
      # list.
      if isinstance(ret, six.binary_type):
        ret = [ret]
      # Ensures that we return either a single numpy array or a list of numpy
      # arrays.
      if isinstance(ret, (tuple, list)):
        return [self._convert(x) for x in ret]
      else:
        return self._convert(ret)

  def size(self):
    """Returns how many functions are currently registered."""
    return len(self._funcs)

  def _next_unique_token(self):
    """Returns a unique token."""
    with self._lock:
      uid = self._unique_id
      self._unique_id += 1
    return "pyfunc_%d" % uid


# Global registry for py functions.
_py_funcs = FuncRegistry()

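# The C++ py_func kernels call back into `_py_funcs` through this trampoline;
# each callback is routed to FuncRegistry.__call__(token, device, args).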
_pywrap_py_func.initialize_py_trampoline(_py_funcs)


def _internal_py_func(func,
                      inp,
                      Tout,
                      stateful=None,
                      eager=False,
                      is_grad_func=False,
                      name=None,
                      use_tape_cache=True):
  """See documentation for py_func and eager_py_func."""
  if not callable(func):
    raise ValueError("Expected func to be callable, got func of type {}".format(
        type(func)))

  original_func = func
  func = autograph.do_not_convert(func)

  is_list_or_tuple = False
  if isinstance(Tout, (list, tuple)):
    is_list_or_tuple = True
  else:
    Tout = [Tout]

  if eager:
    func = EagerFunc(func, Tout, is_grad_func, use_tape_cache=use_tape_cache)

  # Tying the registered function's lifetime with the current default graph is
  # not reliable. For example, Estimator-based binaries may switch graphs in
  # between model training and evaluation, via saved_model. Those binaries work
  # because the original function is global, and break once the registered
  # function is an anonymous lambda, like the one produced by do_not_convert.
  # To avoid breaking those cases, we attach the wrapper to the original
  # function so that their lifetime is connected.
  # TODO(b/144286616): Remove this.
  if tf_inspect.isfunction(original_func):
    # Note: this check is needed because original_func may be a descriptor
    # (https://docs.python.org/3/howto/descriptor.html)
    # and we can't attach attributes to those.
    original_func.ag_dnc_wrapper__ = func

  token = _py_funcs.insert(func)
  # We tie the registered function's lifetime with the current default graph,
  # i.e., when the current graph is destroyed, we remove its py funcs.
  graph = ops.get_default_graph()

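  # Walk up to the outermost graph: a py_func created inside a (possibly
  # nested) function graph should be owned by the outermost graph, which
  # outlives its inner function graphs.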
  while True:
    current_graph = graph
    if isinstance(graph, function._FuncGraph):  # pylint: disable=protected-access
      graph = graph._outer_graph  # pylint: disable=protected-access
    elif isinstance(graph, func_graph.FuncGraph):
      graph = graph.outer_graph
    if graph is current_graph:
      break

  # TODO(zhifengc): Consider adding a Graph method to collect
  # `cleanup` objects in one of its members.
  if not hasattr(graph, "_py_funcs_used_in_graph"):
    graph._py_funcs_used_in_graph = []  # pylint: disable=protected-access

  # Store a reference to the function in the graph to ensure it stays alive
  # as long as the graph lives. When the graph is destroyed, the function
  # is left to the garbage collector for destruction as well.
  graph._py_funcs_used_in_graph.append(func)  # pylint: disable=protected-access

  if eager:
    result = gen_script_ops.eager_py_func(
        input=inp,
        token=token,
        is_async=context.is_async(),
        Tout=Tout,
        name=name)
  else:
    if stateful:
      result = gen_script_ops.py_func(
          input=inp, token=token, Tout=Tout, name=name)
    else:
      result = gen_script_ops.py_func_stateless(
          input=inp, token=token, Tout=Tout, name=name)
  return result if is_list_or_tuple else result[0]


# TODO(akshayka): Implement higher-order derivatives.
@ops.RegisterGradient("EagerPyFunc")
def _EagerPyFuncGrad(op, *dy):
  """Computes the gradient of an EagerPyFunc."""

  token = op.get_attr("token")

  def eagerly_executed_grad(*dy):
    tape, eager_inputs, eager_outputs = tape_cache.pop(compat.as_bytes(token))
    return tape.gradient(eager_outputs, eager_inputs, output_gradients=dy)

  with ops.control_dependencies(op.outputs):
    return _internal_py_func(
        func=eagerly_executed_grad,
        inp=dy,
        Tout=[tensor.dtype for tensor in op.inputs],
        eager=True,
        is_grad_func=True)

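# A sketch of how this gradient is exercised (assuming TF2 eager execution;
# not part of this module's API):
#
#   x = tf.constant(3.0)
#   with tf.GradientTape() as g:
#     g.watch(x)
#     y = tf.py_function(lambda t: t * t, [x], tf.float32)
#   g.gradient(y, x)  # ==> 6.0
#
# The forward EagerPyFunc stashes its tape in `tape_cache` under its token;
# `eagerly_executed_grad` pops that entry and replays the tape.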

def _eager_py_func(func, inp, Tout, name=None, use_tape_cache=True):
  """Wraps a python function into a TensorFlow op that executes it eagerly.

  This function is the internal implementation for `eager_py_func`; see the
  `eager_py_func` docstring for the full description.

  Note: this layer of indirection was added as a workaround for GitHub issue
  #35084. It behaves exactly as `eager_py_func` used to, with one difference:
  it can instruct the underlying EagerFunc not to use the `tape_cache`, which
  avoids a memory leak. Once issue #35084 is fixed, this function should be
  removed, its body moved back into `eager_py_func`, and all call sites
  reverted to calling `eager_py_func` without any `use_tape_cache` argument.

  Args:
    func: A Python function which accepts a list of `Tensor` objects having
      element types that match the corresponding `tf.Tensor` objects in `inp`
      and returns a list of `Tensor` objects (or a single `Tensor`, or `None`)
      having element types that match the corresponding values in `Tout`.
    inp: A list of `Tensor` objects.
    Tout: A list or tuple of tensorflow data types or a single tensorflow data
      type if there is only one, indicating what `func` returns; an empty list
      if no value is returned (i.e., if the return value is `None`).
    name: A name for the operation (optional).
    use_tape_cache: (Optional.) Whether to cache `func` in the `tape_cache`.
      For additional information, see the description of `_eager_py_func`.
      This parameter should be removed once issue #35084 is fixed.

  Returns:
    A list of `Tensor` or a single `Tensor` which `func` computes; an empty list
    if `func` returns None.
  """
  if ops.executing_eagerly_outside_functions():
    with ops.device(context.context().host_address_space()):
      return _internal_py_func(
          func=func,
          inp=inp,
          Tout=Tout,
          eager=True,
          name=name,
          use_tape_cache=use_tape_cache)

  return _internal_py_func(
      func=func,
      inp=inp,
      Tout=Tout,
      eager=True,
      name=name,
      use_tape_cache=use_tape_cache)


@tf_export("py_function")
@dispatch.add_dispatch_support
def eager_py_func(func, inp, Tout, name=None):
  """Wraps a python function into a TensorFlow op that executes it eagerly.

  This function allows expressing computations in a TensorFlow graph as
  Python functions. In particular, it wraps a Python function `func`
  in a once-differentiable TensorFlow operation that executes it with eager
  execution enabled. As a consequence, `tf.py_function` makes it
  possible to express control flow using Python constructs (`if`, `while`,
  `for`, etc.), instead of TensorFlow control flow constructs (`tf.cond`,
  `tf.while_loop`). For example, you might use `tf.py_function` to
  implement the log huber function:

  ```python
  def log_huber(x, m):
    if tf.abs(x) <= m:
      return x**2
    else:
      return m**2 * (1 - 2 * tf.math.log(m) + tf.math.log(x**2))

  x = tf.compat.v1.placeholder(tf.float32)
  m = tf.compat.v1.placeholder(tf.float32)

  y = tf.py_function(func=log_huber, inp=[x, m], Tout=tf.float32)
  dy_dx = tf.gradients(y, x)[0]

  with tf.compat.v1.Session() as sess:
    # The session executes `log_huber` eagerly. Given the feed values below,
    # it will take the first branch, so `y` evaluates to 1.0 and
    # `dy_dx` evaluates to 2.0.
    y, dy_dx = sess.run([y, dy_dx], feed_dict={x: 1.0, m: 2.0})
  ```
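
  The same computation can be written without placeholders or sessions when
  eager execution is enabled (a sketch of TF2-style usage, reusing the
  `log_huber` definition above):

  ```python
  x = tf.constant(1.0)
  m = tf.constant(2.0)

  with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.py_function(func=log_huber, inp=[x, m], Tout=tf.float32)

  dy_dx = tape.gradient(y, x)  # ==> 2.0, matching the session result above.
  ```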

  You can also use `tf.py_function` to debug your models at runtime
  using Python tools, i.e., you can isolate portions of your code that
  you want to debug, wrap them in Python functions and insert `pdb` tracepoints
  or print statements as desired, and wrap those functions in
  `tf.py_function`.

  For more information on eager execution, see the
  [Eager guide](https://tensorflow.org/guide/eager).

  `tf.py_function` is similar in spirit to `tf.compat.v1.py_func`, but unlike
  the latter, the former lets you use TensorFlow operations in the wrapped
  Python function. In particular, while `tf.compat.v1.py_func` only runs on
  CPUs and wraps functions that take NumPy arrays as inputs and return NumPy
  arrays as outputs, `tf.py_function` can be placed on GPUs and wraps functions
  that take Tensors as inputs, execute TensorFlow operations in their bodies,
  and return Tensors as outputs.

  Like `tf.compat.v1.py_func`, `tf.py_function` has the following limitations
  with respect to serialization and distribution:

  * The body of the function (i.e. `func`) will not be serialized in a
    `GraphDef`. Therefore, you should not use this function if you need to
    serialize your model and restore it in a different environment.

  * The operation must run in the same address space as the Python program
    that calls `tf.py_function()`. If you are using distributed
    TensorFlow, you must run a `tf.distribute.Server` in the same process as the
    program that calls `tf.py_function()` and you must pin the created
    operation to a device in that server (e.g. using `with tf.device():`).


  Args:
    func: A Python function which accepts a list of `Tensor` objects having
      element types that match the corresponding `tf.Tensor` objects in `inp`
      and returns a list of `Tensor` objects (or a single `Tensor`, or `None`)
      having element types that match the corresponding values in `Tout`.
    inp: A list of `Tensor` objects.
    Tout: A list or tuple of tensorflow data types or a single tensorflow data
      type if there is only one, indicating what `func` returns; an empty list
      if no value is returned (i.e., if the return value is `None`).
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` or a single `Tensor` which `func` computes; an empty list
    if `func` returns None.
  """
  return _eager_py_func(
      func=func, inp=inp, Tout=Tout, name=name, use_tape_cache=True)


def py_func_common(func, inp, Tout, stateful=True, name=None):
  """Wraps a python function and uses it as a TensorFlow op.

  Given a python function `func`, which takes numpy arrays as its
  arguments and returns numpy arrays as its outputs, wrap this function as an
  operation in a TensorFlow graph. The following snippet constructs a simple
  TensorFlow graph that invokes the `np.sinh()` NumPy function as an operation
  in the graph:

  ```python
  def my_func(x):
    # x will be a numpy array with the contents of the placeholder below
    return np.sinh(x)
  input = tf.compat.v1.placeholder(tf.float32)
  y = tf.compat.v1.py_func(my_func, [input], tf.float32)
  ```

  **N.B.** The `tf.compat.v1.py_func()` operation has the following known
  limitations:

  * The body of the function (i.e. `func`) will not be serialized in a
    `GraphDef`. Therefore, you should not use this function if you need to
    serialize your model and restore it in a different environment.

  * The operation must run in the same address space as the Python program
    that calls `tf.compat.v1.py_func()`. If you are using distributed
    TensorFlow, you must run a `tf.distribute.Server` in the same process as
    the program that calls `tf.compat.v1.py_func()` and you must pin the
    created operation to a device in that server (e.g. using
    `with tf.device():`).

  Note: It produces tensors of unknown shape and rank, as shape inference
    does not work on arbitrary Python code.
    If you need the shape, you need to set it based on statically
    available information.

    E.g.
    ```python
    import tensorflow as tf
    import numpy as np

    def make_synthetic_data(i):
        return np.cast[np.uint8](i) * np.ones([20, 256, 256, 3],
                dtype=np.float32) / 10.

    def preprocess_fn(i):
        ones = tf.py_function(make_synthetic_data, [i], tf.float32)
        ones.set_shape(tf.TensorShape([None, None, None, None]))
        ones = tf.image.resize(ones, [224, 224])
        return ones

    ds = tf.data.Dataset.range(10)
    ds = ds.map(preprocess_fn)
    ```

  Args:
    func: A Python function, which accepts `ndarray` objects as arguments and
      returns a list of `ndarray` objects (or a single `ndarray`). This function
      must accept as many arguments as there are tensors in `inp`, and these
      argument types will match the corresponding `tf.Tensor` objects in `inp`.
      The returned `ndarray`s must match the number and types defined in
      `Tout`.
      Important Note: Input and output numpy `ndarray`s of `func` are not
        guaranteed to be copies. In some cases their underlying memory will be
        shared with the corresponding TensorFlow tensors. In-place modification
        or storing `func` input or return values in Python data structures
        without an explicit copy (e.g. `np.copy`) can have non-deterministic
        consequences.
    inp: A list of `Tensor` objects.
    Tout: A list or tuple of tensorflow data types or a single tensorflow data
      type if there is only one, indicating what `func` returns.
    stateful: (Boolean.) If True, the function should be considered stateful. If
      a function is stateless, when given the same input it will return the same
      output and have no observable side effects. Optimizations such as common
      subexpression elimination are only performed on stateless operations.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` or a single `Tensor` which `func` computes.
  """
  if context.executing_eagerly():
    result = func(*[np.array(x) for x in inp])
    result = nest.flatten(result)

    result = [x if x is None else ops.convert_to_tensor(x) for x in result]
    if len(result) == 1:
      # Mimic the automatic unwrapping in graph-mode py_func
      result, = result
    return result

  if ops.executing_eagerly_outside_functions():
    with ops.device(context.context().host_address_space()):
      return _internal_py_func(
          func=func,
          inp=inp,
          Tout=Tout,
          stateful=stateful,
          eager=False,
          name=name)

  return _internal_py_func(
      func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name)


@deprecation.deprecated(
    date=None,
    instructions="""tf.py_func is deprecated in TF V2. Instead, there are two
    options available in V2.
    - tf.py_function takes a python function which manipulates tf eager
    tensors instead of numpy arrays. It's easy to convert a tf eager tensor to
    an ndarray (just call tensor.numpy()) but having access to eager tensors
    means `tf.py_function`s can use accelerators such as GPUs as well as
    being differentiable using a gradient tape.
    - tf.numpy_function maintains the semantics of the deprecated tf.py_func
    (it is not differentiable, and manipulates numpy arrays). It drops the
    stateful argument making all functions stateful.
    """)
@tf_export(v1=["py_func"])
@dispatch.add_dispatch_support
def py_func(func, inp, Tout, stateful=True, name=None):
  return py_func_common(func, inp, Tout, stateful, name=name)


py_func.__doc__ = "%s" % py_func_common.__doc__


@tf_export("numpy_function")
@dispatch.add_dispatch_support
def numpy_function(func, inp, Tout, name=None):
  """Wraps a python function and uses it as a TensorFlow op.

  Given a python function `func` wrap this function as an operation in a
  TensorFlow function. `func` must take numpy arrays as its arguments and
  return numpy arrays as its outputs.

  The following example creates a TensorFlow graph with `np.sinh()` as an
  operation in the graph:

  >>> def my_numpy_func(x):
  ...   # x will be a numpy array with the contents of the input to the
  ...   # tf.function
  ...   return np.sinh(x)
  >>> @tf.function(input_signature=[tf.TensorSpec(None, tf.float32)])
  ... def tf_function(input):
  ...   y = tf.numpy_function(my_numpy_func, [input], tf.float32)
  ...   return y * y
  >>> tf_function(tf.constant(1.))
  <tf.Tensor: shape=(), dtype=float32, numpy=1.3810978>

  Comparison to `tf.py_function`:
  `tf.py_function` and `tf.numpy_function` are very similar, except that
  `tf.numpy_function` takes numpy arrays, and not `tf.Tensor`s. If you want the
  function to contain `tf.Tensor`s, and have any TensorFlow operations executed
  in the function be differentiable, please use `tf.py_function`.

  Note: The `tf.numpy_function` operation has the following known
  limitations:

  * The body of the function (i.e. `func`) will not be serialized in a
    `tf.SavedModel`. Therefore, you should not use this function if you need to
    serialize your model and restore it in a different environment.

  * The operation must run in the same address space as the Python program
    that calls `tf.numpy_function()`. If you are using distributed
    TensorFlow, you must run a `tf.distribute.Server` in the same process as
    the program that calls `tf.numpy_function` and you must pin the created
    operation to a device in that server (e.g. using `with tf.device():`).

  * Since the function takes numpy arrays, you cannot take gradients
    through a numpy_function (see the sketch after this list). If you require
    something that is differentiable, please consider using `tf.py_function`.

  * The resulting function is assumed stateful and will never be optimized.

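  For example, no gradient flows back through the op (a sketch reusing
  `my_numpy_func` from the example above):

  >>> x = tf.constant(1.)
  >>> with tf.GradientTape() as tape:
  ...   tape.watch(x)
  ...   y = tf.numpy_function(my_numpy_func, [x], tf.float32)
  >>> print(tape.gradient(y, x))
  None
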
  Args:
    func: A Python function, which accepts `numpy.ndarray` objects as arguments
      and returns a list of `numpy.ndarray` objects (or a single
      `numpy.ndarray`). This function must accept as many arguments as there are
      tensors in `inp`, and these argument types will match the corresponding
      `tf.Tensor` objects in `inp`. The returned `numpy.ndarray`s must match
      the number and types defined in `Tout`.
      Important Note: Input and output `numpy.ndarray`s of `func` are not
        guaranteed to be copies. In some cases their underlying memory will be
        shared with the corresponding TensorFlow tensors. In-place modification
        or storing `func` input or return values in Python data structures
        without an explicit copy (e.g. `np.copy`) can have non-deterministic
        consequences.
    inp: A list of `tf.Tensor` objects.
    Tout: A list or tuple of tensorflow data types or a single tensorflow data
      type if there is only one, indicating what `func` returns.
    name: (Optional) A name for the operation.

  Returns:
    A single `tf.Tensor` or a list of `tf.Tensor`s which `func` computes.
709  """
710  return py_func_common(func, inp, Tout, stateful=True, name=name)
711
712
713ops.NotDifferentiable("PyFunc")
714ops.NotDifferentiable("PyFuncStateless")
715