# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A library of common shape functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.framework import tensor_shape


def _broadcast_shape_helper(shape_x, shape_y):
  """Helper functions for is_broadcast_compatible and broadcast_shape.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    Returns None if the shapes are not broadcast compatible,
    a list of the broadcast dimensions otherwise.
  """
  # Align the two shapes on their trailing dimensions by reversing them,
  # padding the shorter one with size-1 dimensions, and reversing back.
  pad_one = tensor_shape.Dimension(1)
  aligned_pairs = reversed(list(six.moves.zip_longest(
      reversed(shape_x.dims),
      reversed(shape_y.dims),
      fillvalue=pad_one)))
  # Combine each aligned pair of dimensions following the numpy broadcasting
  # rules: http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
  merged = []
  for (dx, dy) in aligned_pairs:
    if dx.value is None or dy.value is None:
      # At least one side is unknown. If either side is statically greater
      # than 1, assume the program is correct and that the other side will
      # broadcast to match it.
      # TODO(mrry): If we eliminate the shape checks in C++, we must still
      # assert that the unknown dim is either 1 or the same as the known dim.
      if dx.value is not None and dx.value > 1:
        merged.append(dx)
      elif dy.value is not None and dy.value > 1:
        merged.append(dy)
      else:
        merged.append(None)
    elif dx.value == 1:
      # A size-1 dim_x broadcasts to dim_y.
      merged.append(dy)
    elif dy.value == 1:
      # A size-1 dim_y broadcasts to dim_x.
      merged.append(dx)
    elif dx.value == dy.value:
      # Equal known sizes: the output keeps that size in this dimension.
      merged.append(dx.merge_with(dy))
    else:
      # Unequal sizes, neither of which is 1: not broadcast compatible.
      return None
  return merged


def is_broadcast_compatible(shape_x, shape_y):
  """Returns True if `shape_x` and `shape_y` are broadcast compatible.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    True if a shape exists that both `shape_x` and `shape_y` can be broadcasted
    to. False otherwise.
  """
  # Shapes of unknown rank are conservatively treated as incompatible.
  if shape_x.ndims is not None and shape_y.ndims is not None:
    return _broadcast_shape_helper(shape_x, shape_y) is not None
  return False


def broadcast_shape(shape_x, shape_y):
  """Returns the broadcasted shape between `shape_x` and `shape_y`.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    A `TensorShape` representing the broadcasted shape.

  Raises:
    ValueError: If the two shapes can not be broadcasted.
  """
  # If either rank is unknown, nothing can be said about the result shape.
  if shape_x.ndims is None or shape_y.ndims is None:
    return tensor_shape.unknown_shape()
  combined = _broadcast_shape_helper(shape_x, shape_y)
  if combined is None:
    raise ValueError("Incompatible shapes for broadcasting: %s and %s"
                     % (shape_x, shape_y))
  return tensor_shape.TensorShape(combined)