• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Utilities for computing default gradients."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import tensor_shape
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import resource_variable_ops
24
25
def get_zeros_dtype(t):
  """Return the dtype for the default gradient for a Tensor.

  For a resource tensor the dtype is read from the single shape/type entry in
  the handle data returned by
  `resource_variable_ops.get_eager_safe_handle_data`. For any other tensor,
  `t.dtype` is returned as-is.

  Args:
    t: Tensor whose default-gradient dtype is wanted.

  Returns:
    For resource tensors, `dtypes.as_dtype` applied to the handle-data entry's
    dtype (the same normalization `shape_and_dtype` performs); otherwise
    `t.dtype`.

  Raises:
    ValueError: If `t` is a resource tensor whose handle data is missing, not
      set, or does not contain exactly one shape/type entry.
  """
  if t.dtype == dtypes.resource:
    handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
    if (handle_data is None or not handle_data.is_set or
        len(handle_data.shape_and_type) != 1):
      raise ValueError("Internal error: Tried to take gradients (or similar) "
                       "of a variable without handle data:\n%s" % str(t))
    # The handle-data entry stores the dtype as a proto enum field rather than
    # a DType (shape_and_dtype converts the same field). Normalize it so both
    # branches of this function return the same kind of object.
    return dtypes.as_dtype(handle_data.shape_and_type[0].dtype)
  return t.dtype
36
37
def shape_and_dtype(t):
  """Computes the (shape, dtype) pair describing the default gradient of `t`.

  A non-resource tensor simply reports `t.shape` and `t.dtype`. For a resource
  tensor, both values are taken from the one shape/type entry recorded in the
  handle data returned by `resource_variable_ops.get_eager_safe_handle_data`.

  Args:
    t: Tensor.

  Returns:
    A `(shape, dtype)` tuple. For resource tensors the shape is wrapped in
    `tensor_shape.TensorShape` and the dtype is passed through
    `dtypes.as_dtype`.

  Raises:
    ValueError: When `t` is a resource tensor and its handle data is absent,
      unset, or holds a number of shape/type entries other than one.
  """
  if t.dtype != dtypes.resource:
    return t.shape, t.dtype

  handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
  usable = (handle_data is not None and handle_data.is_set and
            len(handle_data.shape_and_type) == 1)
  if not usable:
    raise ValueError("Internal error: Tried to take gradients (or similar) "
                     "of a variable without handle data:\n%s" % str(t))

  entry = handle_data.shape_and_type[0]
  shape = tensor_shape.TensorShape(entry.shape)
  dtype = dtypes.as_dtype(entry.dtype)
  return shape, dtype
50
51
def zeros_like(t):
  """Builds an all-zeros tensor matching `t`, looking through resource handles.

  Ordinary tensors go straight to `array_ops.zeros_like`. A resource tensor is
  instead described by `shape_and_dtype(t)`, and `array_ops.zeros` is called
  with that shape and dtype.

  Args:
    t: Tensor.

  Returns:
    The result of `array_ops.zeros_like(t)` or `array_ops.zeros(shape, dtype)`.

  Raises:
    ValueError: Propagated from `shape_and_dtype` for a resource tensor
      without usable handle data.
  """
  if t.dtype != dtypes.resource:
    return array_ops.zeros_like(t)
  shape, dtype = shape_and_dtype(t)
  return array_ops.zeros(shape, dtype)
58
59
def ones_like(t):
  """Builds an all-ones tensor matching `t`, looking through resource handles.

  Ordinary tensors go straight to `array_ops.ones_like`. A resource tensor is
  instead described by `shape_and_dtype(t)`, and `array_ops.ones` is called
  with that shape and dtype.

  Args:
    t: Tensor.

  Returns:
    The result of `array_ops.ones_like(t)` or `array_ops.ones(shape, dtype)`.

  Raises:
    ValueError: Propagated from `shape_and_dtype` for a resource tensor
      without usable handle data.
  """
  if t.dtype != dtypes.resource:
    return array_ops.ones_like(t)
  shape, dtype = shape_and_dtype(t)
  return array_ops.ones(shape, dtype)
66
67
def supports_default_grad(t):
  """Reports whether a default gradient can be built for tensor `t`.

  `t` is assumed to already be of a trainable type. Every non-resource tensor
  qualifies. A resource tensor qualifies only when its handle data exists, is
  set, and carries exactly one shape/type entry; these are the same checks
  that make `shape_and_dtype` return instead of raising.

  Args:
    t: Tensor

  Returns:
    `True` if a default gradient is supported for `t`, else `False`.
  """
  if t.dtype != dtypes.resource:
    return True
  handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
  return bool(handle_data is not None and handle_data.is_set and
              len(handle_data.shape_and_type) == 1)
85