# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Weight broadcasting operations.

In `tf.losses` and `tf.metrics`, we support limited weight broadcasting. This
file includes operations for those broadcasting rules.
"""

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sets
from tensorflow.python.util.tf_export import tf_export


def _has_valid_dims(weights_shape, values_shape):
  with ops.name_scope(
32      None, "has_invalid_dims", (weights_shape, values_shape)) as scope:
    values_shape_2d = array_ops.expand_dims(values_shape, -1)
    valid_dims = array_ops.concat(
        (values_shape_2d, array_ops.ones_like(values_shape_2d)), axis=1)
    weights_shape_2d = array_ops.expand_dims(weights_shape, -1)
    invalid_dims = sets.set_difference(weights_shape_2d, valid_dims)
    num_invalid_dims = array_ops.size(
        invalid_dims.values, name="num_invalid_dims")
    return math_ops.equal(0, num_invalid_dims, name=scope)

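# How the set-based check above works, with hypothetical shapes (an editorial
# note, not part of the original module): row i of `valid_dims` is the set
# {values_shape[i], 1}, and `set_difference` is applied per row, so any
# weights dimension left over after the difference is invalid.
#
#   values_shape  = [2, 3]  ->  valid_dims = [[2, 1], [3, 1]]
#   weights_shape = [2, 1]  ->  {2}-{2, 1} = {}, {1}-{3, 1} = {}  -> valid
#   weights_shape = [2, 2]  ->  {2}-{2, 1} = {}, {2}-{3, 1} = {2} -> invalid
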

def _has_valid_nonscalar_shape(
    weights_rank, weights_shape, values_rank, values_shape):
  with ops.name_scope(
      None, "has_valid_nonscalar_shape",
      (weights_rank, weights_shape, values_rank, values_shape)) as scope:
    is_same_rank = math_ops.equal(
        values_rank, weights_rank, name="is_same_rank")
    return control_flow_ops.cond(
        is_same_rank,
        lambda: _has_valid_dims(weights_shape, values_shape),
        lambda: is_same_rank,
        name=scope)


_ASSERT_BROADCASTABLE_ERROR_PREFIX = "weights cannot be broadcast to values."


def assert_broadcastable(weights, values):
  """Asserts `weights` can be broadcast to `values`.

  In `tf.losses` and `tf.metrics`, we support limited weight broadcasting. We
  let weights be either scalar, or the same rank as the target values, with each
  dimension either 1, or the same as the corresponding values dimension.

  Args:
    weights: `Tensor` of weights.
    values: `Tensor` of values to which weights are applied.

  Returns:
    `Operation` raising `InvalidArgumentError` if `weights` has incorrect shape.
    `no_op` if static checks determine `weights` has correct shape.

  Raises:
    ValueError:  If static checks determine `weights` has incorrect shape.
  """
  with ops.name_scope(None, "assert_broadcastable", (weights, values)) as scope:
    with ops.name_scope(None, "weights", (weights,)) as weights_scope:
      weights = ops.convert_to_tensor(weights, name=weights_scope)
      weights_shape = array_ops.shape(weights, name="shape")
      weights_rank = array_ops.rank(weights, name="rank")
    weights_rank_static = tensor_util.constant_value(weights_rank)

    with ops.name_scope(None, "values", (values,)) as values_scope:
      values = ops.convert_to_tensor(values, name=values_scope)
      values_shape = array_ops.shape(values, name="shape")
      values_rank = array_ops.rank(values, name="rank")
    values_rank_static = tensor_util.constant_value(values_rank)

    # Try static checks.
    if weights_rank_static is not None and values_rank_static is not None:
      if weights_rank_static == 0:
        return control_flow_ops.no_op(name="static_scalar_check_success")
      if weights_rank_static != values_rank_static:
        raise ValueError(
            f"{_ASSERT_BROADCASTABLE_ERROR_PREFIX} values.rank="
            f"{values_rank_static}. weights.rank={weights_rank_static}. "
            f"values.shape={values.shape}. weights.shape={weights.shape}. "
            f"Received weights={weights}, values={values}")
      weights_shape_static = tensor_util.constant_value(weights_shape)
      values_shape_static = tensor_util.constant_value(values_shape)
      if weights_shape_static is not None and values_shape_static is not None:
        # Sanity check, this should always be true since we checked rank above.
        ndims = len(values_shape_static)
        assert ndims == len(weights_shape_static)

        for i in range(ndims):
          if weights_shape_static[i] not in (1, values_shape_static[i]):
            raise ValueError(
                f"{_ASSERT_BROADCASTABLE_ERROR_PREFIX} Mismatch at dim {i}. "
                f"values.shape={values_shape_static}, weights.shape="
                f"{weights_shape_static}. Received weights={weights}, "
                f"values={values}")
        return control_flow_ops.no_op(name="static_dims_check_success")

    # Dynamic checks.
    is_scalar = math_ops.equal(0, weights_rank, name="is_scalar")
    data = (
        _ASSERT_BROADCASTABLE_ERROR_PREFIX,
        "weights.shape=", weights.name, weights_shape,
        "values.shape=", values.name, values_shape,
        "is_scalar=", is_scalar,
    )
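    # Note: both `cond` branches must return a tensor of the same type. When
    # `weights` is a scalar, the true branch just returns the `is_scalar`
    # predicate itself (which is True there); otherwise the full nonscalar
    # shape check runs.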
    is_valid_shape = control_flow_ops.cond(
        is_scalar,
        lambda: is_scalar,
        lambda: _has_valid_nonscalar_shape(  # pylint: disable=g-long-lambda
            weights_rank, weights_shape, values_rank, values_shape),
        name="is_valid_shape")
    return control_flow_ops.Assert(is_valid_shape, data, name=scope)

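# Illustrative usage sketch (editorial, not part of the original module;
# assumes `import tensorflow as tf` and hypothetical tensors). In graph mode,
# run the weighted op under a control dependency on the assertion:
#
#   weights = tf.constant([[0.5], [2.0]])        # shape (2, 1)
#   values = tf.constant([[1., 2.], [3., 4.]])   # shape (2, 2)
#   # Same rank, and each weights dim is 1 or matches, so this is a no_op.
#   assert_op = assert_broadcastable(weights, values)
#   with ops.control_dependencies((assert_op,)):
#     weighted = math_ops.multiply(weights, values)
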

@tf_export("__internal__.ops.broadcast_weights", v1=[])
def broadcast_weights(weights, values):
  """Broadcast `weights` to the same shape as `values`.

  This returns a version of `weights` following the same broadcast rules as
  `mul(weights, values)`, but limited to the weights shapes allowed by
  `assert_broadcastable`. When computing a weighted average, use this function
  to broadcast `weights` before summing them; e.g.,
  `reduce_sum(w * v) / reduce_sum(broadcast_weights(w, v))`.

  Args:
    weights: `Tensor` whose shape is broadcastable to `values` according to the
      rules of `assert_broadcastable`.
    values: `Tensor` of any shape.

  Returns:
    `weights` broadcast to `values` shape according to the rules of
      `assert_broadcastable`.
  """
  with ops.name_scope(None, "broadcast_weights", (weights, values)) as scope:
    values = ops.convert_to_tensor(values, name="values")
    weights = ops.convert_to_tensor(
        weights, dtype=values.dtype.base_dtype, name="weights")

    # Try static check for exact match.
    weights_shape = weights.get_shape()
    values_shape = values.get_shape()
    if (weights_shape.is_fully_defined() and
        values_shape.is_fully_defined() and
        weights_shape.is_compatible_with(values_shape)):
      return weights

    # Skip assert_broadcastable on TPU/GPU: asserts are not supported there,
    # so the check would only add unnecessary ops. It also relies on a
    # DenseToDenseSetOperation op that is incompatible with TPU/GPU when the
    # shapes are dynamic.
    if control_flow_ops.get_enclosing_xla_context() is not None:
      return math_ops.multiply(
          weights, array_ops.ones_like(values), name=scope)
    with ops.control_dependencies((assert_broadcastable(weights, values),)):
      return math_ops.multiply(
          weights, array_ops.ones_like(values), name=scope)
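

# Illustrative usage sketch of a weighted mean (editorial, not part of the
# original module; assumes `import tensorflow as tf` and hypothetical values):
#
#   w = tf.constant([[0.5], [2.0]])        # shape (2, 1)
#   v = tf.constant([[1., 2.], [3., 4.]])  # shape (2, 2)
#   w_full = tf.__internal__.ops.broadcast_weights(w, v)  # shape (2, 2)
#   weighted_mean = tf.reduce_sum(w_full * v) / tf.reduce_sum(w_full)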