• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A library of common shape functions."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import six
21
22from tensorflow.python.framework import tensor_shape
23
24
def _broadcast_shape_helper(shape_x, shape_y):
  """Computes the per-dimension result of broadcasting two shapes.

  Shared implementation behind `is_broadcast_compatible` and
  `broadcast_shape`.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    None if the shapes are not broadcast compatible. Otherwise a list with one
    entry per output dimension, outermost first; an entry is None when the
    size of that output dimension cannot be determined.
  """
  # Align the two shapes on their trailing dimensions. The shorter shape is
  # padded with size-1 dimensions so both contribute to every pair.
  aligned_pairs = list(six.moves.zip_longest(
      reversed(shape_x.dims),
      reversed(shape_y.dims),
      fillvalue=tensor_shape.Dimension(1)))

  # Combine each pair according to the numpy broadcasting rules.
  # http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
  result = []
  # The pairs were built innermost-first; walk them outermost-first.
  for left, right in reversed(aligned_pairs):
    left_size = left.value
    right_size = right.value

    if left_size is None or right_size is None:
      # One or both dimensions is unknown. If the known dimension is greater
      # than 1, we assume that the program is correct, and the unknown
      # dimension will be broadcast to match it.
      # TODO(mrry): If we eliminate the shape checks in C++, we must still
      # assert that the unknown dim is either 1 or the same as the known dim.
      known = left if right_size is None else right
      if known.value is not None and known.value > 1:
        result.append(known)
      else:
        result.append(None)
      continue

    if left_size == 1:
      # `left` is stretched to the size of `right`.
      result.append(right)
      continue

    if right_size == 1:
      # `right` is stretched to the size of `left`.
      result.append(left)
      continue

    if left_size != right_size:
      # Two different sizes, neither of which is 1: cannot broadcast.
      return None

    # Equal sizes: the output has the same size in this dimension.
    result.append(left.merge_with(right))

  return result
71
72
def is_broadcast_compatible(shape_x, shape_y):
  """Checks whether `shape_x` and `shape_y` can be broadcast together.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    True if a shape exists that both `shape_x` and `shape_y` can be broadcasted
    to.  False otherwise, including when the rank of either shape is unknown.
  """
  ranks_known = shape_x.ndims is not None and shape_y.ndims is not None
  if not ranks_known:
    # Without both ranks the dimensions cannot be paired up, so this
    # conservatively reports the shapes as incompatible.
    return False
  combined_dims = _broadcast_shape_helper(shape_x, shape_y)
  return combined_dims is not None
87
88
def broadcast_shape(shape_x, shape_y):
  """Computes the shape produced by broadcasting `shape_x` with `shape_y`.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    A `TensorShape` representing the broadcasted shape. If the rank of either
    input is unknown, the result of `tensor_shape.unknown_shape()` is returned.

  Raises:
    ValueError: If the two shapes can not be broadcasted.
  """
  ranks_known = shape_x.ndims is not None and shape_y.ndims is not None
  if ranks_known:
    combined_dims = _broadcast_shape_helper(shape_x, shape_y)
    if combined_dims is None:
      raise ValueError("Incompatible shapes for broadcasting: %s and %s"
                       % (shape_x, shape_y))
    return tensor_shape.TensorShape(combined_dims)
  # At least one rank is unknown, so nothing can be said about the result.
  return tensor_shape.unknown_shape()
109