# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Weight broadcasting operations.

In `tf.losses` and `tf.metrics`, we support limited weight broadcasting. This
file includes operations for those broadcasting rules.
"""

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sets
from tensorflow.python.util.tf_export import tf_export


def _has_valid_dims(weights_shape, values_shape):
  with ops.name_scope(
      None, "has_invalid_dims", (weights_shape, values_shape)) as scope:
    values_shape_2d = array_ops.expand_dims(values_shape, -1)
    valid_dims = array_ops.concat(
        (values_shape_2d, array_ops.ones_like(values_shape_2d)), axis=1)
    weights_shape_2d = array_ops.expand_dims(weights_shape, -1)
    invalid_dims = sets.set_difference(weights_shape_2d, valid_dims)
    num_invalid_dims = array_ops.size(
        invalid_dims.values, name="num_invalid_dims")
    return math_ops.equal(0, num_invalid_dims, name=scope)
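

# The check above is per-dimension: `valid_dims` pairs each values dimension
# `d_i` with 1, and the row-wise `sets.set_difference` flags any weights
# dimension outside {d_i, 1}. The helper below is a worked sketch of that
# behavior, added for exposition only: `_example_has_valid_dims` is not part
# of TensorFlow, is never called by this module, and assumes an eager TF 2.x
# runtime.
def _example_has_valid_dims():
  """Shows `_has_valid_dims` accepting [2, 1] and rejecting [3, 2]."""
  from tensorflow.python.framework import constant_op  # pylint: disable=g-import-not-at-top
  values_shape = constant_op.constant([2, 3])
  # valid_dims is [[2, 1], [3, 1]]: each weights dim must be 1 or match.
  ok = _has_valid_dims(constant_op.constant([2, 1]), values_shape)   # True
  bad = _has_valid_dims(constant_op.constant([3, 2]), values_shape)  # False
  return ok, bad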


def _has_valid_nonscalar_shape(
    weights_rank, weights_shape, values_rank, values_shape):
  with ops.name_scope(
      None, "has_valid_nonscalar_shape",
      (weights_rank, weights_shape, values_rank, values_shape)) as scope:
    is_same_rank = math_ops.equal(
        values_rank, weights_rank, name="is_same_rank")
    return control_flow_ops.cond(
        is_same_rank,
        lambda: _has_valid_dims(weights_shape, values_shape),
        lambda: is_same_rank,
        name=scope)


_ASSERT_BROADCASTABLE_ERROR_PREFIX = "weights can not be broadcast to values."


def assert_broadcastable(weights, values):
  """Asserts `weights` can be broadcast to `values`.

  In `tf.losses` and `tf.metrics`, we support limited weight broadcasting. We
  let weights be either scalar, or the same rank as the target values, with
  each dimension either 1, or the same as the corresponding values dimension.

  Args:
    weights: `Tensor` of weights.
    values: `Tensor` of values to which weights are applied.

  Returns:
    `Operation` raising `InvalidArgumentError` if `weights` has incorrect
    shape. `no_op` if static checks determine `weights` has correct shape.

  Raises:
    ValueError: If static checks determine `weights` has incorrect shape.
  """
  with ops.name_scope(None, "assert_broadcastable", (weights, values)) as scope:
    with ops.name_scope(None, "weights", (weights,)) as weights_scope:
      weights = ops.convert_to_tensor(weights, name=weights_scope)
      weights_shape = array_ops.shape(weights, name="shape")
      weights_rank = array_ops.rank(weights, name="rank")
    weights_rank_static = tensor_util.constant_value(weights_rank)

    with ops.name_scope(None, "values", (values,)) as values_scope:
      values = ops.convert_to_tensor(values, name=values_scope)
      values_shape = array_ops.shape(values, name="shape")
      values_rank = array_ops.rank(values, name="rank")
    values_rank_static = tensor_util.constant_value(values_rank)

    # Try static checks.
    if weights_rank_static is not None and values_rank_static is not None:
      if weights_rank_static == 0:
        return control_flow_ops.no_op(name="static_scalar_check_success")
      if weights_rank_static != values_rank_static:
        raise ValueError(
            f"{_ASSERT_BROADCASTABLE_ERROR_PREFIX} values.rank="
            f"{values_rank_static}. weights.rank={weights_rank_static}. "
            f"values.shape={values.shape}. weights.shape={weights.shape}. "
            f"Received weights={weights}, values={values}")
      weights_shape_static = tensor_util.constant_value(weights_shape)
      values_shape_static = tensor_util.constant_value(values_shape)
      if weights_shape_static is not None and values_shape_static is not None:
        # Sanity check, this should always be true since we checked rank above.
        ndims = len(values_shape_static)
        assert ndims == len(weights_shape_static)

        for i in range(ndims):
          if weights_shape_static[i] not in (1, values_shape_static[i]):
            raise ValueError(
                f"{_ASSERT_BROADCASTABLE_ERROR_PREFIX} Mismatch at dim {i}. "
                f"values.shape={values_shape_static}, weights.shape="
                f"{weights_shape_static}. Received weights={weights}, "
                f"values={values}")
        return control_flow_ops.no_op(name="static_dims_check_success")

    # Dynamic checks.
    is_scalar = math_ops.equal(0, weights_rank, name="is_scalar")
    data = (
        _ASSERT_BROADCASTABLE_ERROR_PREFIX,
        "weights.shape=", weights.name, weights_shape,
        "values.shape=", values.name, values_shape,
        "is_scalar=", is_scalar,
    )
    is_valid_shape = control_flow_ops.cond(
        is_scalar,
        lambda: is_scalar,
        lambda: _has_valid_nonscalar_shape(  # pylint: disable=g-long-lambda
            weights_rank, weights_shape, values_rank, values_shape),
        name="is_valid_shape")
    return control_flow_ops.Assert(is_valid_shape, data, name=scope)
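

# `assert_broadcastable` resolves as much as it can at graph-construction
# time and defers the rest to a run-time `Assert`. The helper below sketches
# both static outcomes; it is exposition only (`_example_assert_broadcastable`
# is not part of TensorFlow, is never called by this module, and assumes an
# eager TF 2.x runtime).
def _example_assert_broadcastable():
  """Shows the static no_op path and the static ValueError path."""
  from tensorflow.python.framework import constant_op  # pylint: disable=g-import-not-at-top
  values = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])  # shape [2, 2]
  # Same rank, and each weights dim is 1 or matches: resolves to a no_op.
  assert_broadcastable(constant_op.constant([[0.5], [2.0]]), values)
  # Rank mismatch with fully static ranks: rejected before run time.
  try:
    assert_broadcastable(constant_op.constant([0.5, 2.0]), values)
  except ValueError:
    pass  # weights.rank=1 vs values.rank=2 is not supported.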


@tf_export("__internal__.ops.broadcast_weights", v1=[])
def broadcast_weights(weights, values):
  """Broadcast `weights` to the same shape as `values`.

  This returns a version of `weights` following the same broadcast rules as
  `mul(weights, values)`, but limited to the weights shapes allowed by
  `assert_broadcastable`. When computing a weighted average, use this function
  to broadcast `weights` before summing them; e.g.,
  `reduce_sum(w * v) / reduce_sum(broadcast_weights(w, v))`.

  Args:
    weights: `Tensor` whose shape is broadcastable to `values` according to
      the rules of `assert_broadcastable`.
    values: `Tensor` of any shape.

  Returns:
    `weights` broadcast to `values` shape according to the rules of
    `assert_broadcastable`.
  """
  with ops.name_scope(None, "broadcast_weights", (weights, values)) as scope:
    values = ops.convert_to_tensor(values, name="values")
    weights = ops.convert_to_tensor(
        weights, dtype=values.dtype.base_dtype, name="weights")

    # Try static check for exact match.
    weights_shape = weights.get_shape()
    values_shape = values.get_shape()
    if (weights_shape.is_fully_defined() and
        values_shape.is_fully_defined() and
        weights_shape.is_compatible_with(values_shape)):
      return weights

    # Skip assert_broadcastable on TPU/GPU: asserts are not supported there,
    # so the check would only add unnecessary ops, and it relies on a
    # DenseToDenseSetOperation op that is incompatible with TPU/GPU when the
    # shape(s) are dynamic.
    if control_flow_ops.get_enclosing_xla_context() is not None:
      return math_ops.multiply(
          weights, array_ops.ones_like(values), name=scope)
    with ops.control_dependencies((assert_broadcastable(weights, values),)):
      return math_ops.multiply(
          weights, array_ops.ones_like(values), name=scope)
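

if __name__ == "__main__":
  # Minimal demonstration of the weighted-average pattern suggested in the
  # `broadcast_weights` docstring. This block is a sketch added for
  # exposition, not part of the library surface; it assumes a TF 2.x eager
  # runtime.
  from tensorflow.python.framework import constant_op  # pylint: disable=g-import-not-at-top

  example_values = constant_op.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
  example_weights = constant_op.constant([[0.0], [1.0]])  # shape [2, 1]

  # The [2, 1] weights tile across columns to [2, 3] to match the values.
  tiled = broadcast_weights(example_weights, example_values)
  weighted_mean = (
      math_ops.reduce_sum(example_weights * example_values) /
      math_ops.reduce_sum(tiled))
  print(tiled.numpy())          # [[0. 0. 0.], [1. 1. 1.]]
  print(weighted_mean.numpy())  # (4 + 5 + 6) / 3 = 5.0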