• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_UTIL_BCAST_H_
17 #define TENSORFLOW_CORE_UTIL_BCAST_H_
18 
19 #include <algorithm>
20 
21 #include "tensorflow/core/framework/tensor_shape.h"
22 #include "tensorflow/core/lib/gtl/inlined_vector.h"
23 #include "tensorflow/core/platform/macros.h"
24 #include "tensorflow/core/platform/types.h"
25 
26 namespace tensorflow {
27 
28 // BCast is a helper for broadcasting binary tensor operation.
29 // TensorFlow's broadcasting rule follows that of numpy (See
30 // http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
31 //
32 // The rule has the following properties:
33 //
34 //   1. suffix matching: the rule starts with the right-most
35 //      dimension, and works towards the left-most dimension. Since
36 //      TensorFlow is row-major, the right-most dimension (the last
37 //      element in the shape of a tensor) is the inner-most, a.k.a.
38 //      the fastest changing, dimension.
39 //
40 //   2. Two dimensions are compatible for broadcasting if both are the
41 //      same or either is 1.
42 //
43 // BCast takes the shape of two tensors and computes a few vectors of
44 // int32 that are useful for the caller to reshape the tensors, apply
45 // the right broadcasts to them, compute the broadcasted operation,
46 // and possibly the gradients. In a nutshell, the caller is expected
47 // to compute the broadcasted operation as following:
48 //
49 //   BCast b(x.shape(), y.shape());
50 //   output = x.reshape(b.x_reshape()).broadcast(b.x_bcast())
51 //            _op_
52 //            y.reshape(b.y_reshape()).broadcast(b.y_bcast())
53 //
54 // For the gradient computation,
55 //   grad_x = sum(grad * backprop_x(x, y), grad_x_reduce_idx)
56 //            .reshape(x.shape())
57 //   grad_y = sum(grad * backprop_y(x, y), grad_y_reduce_idx)
58 //            .reshape(y.shape())
59 // backprop_x and backprop_y are functionals of the binary function "op",
60 // e.g.,
61 //   for +, backprop_x(x, y) = backprop_y(x, y) = 1;
62 //   for *, backprop_x(x, y) =  y, backprop_y(x, y) = x;
63 //   for /, backprop_x(x, y) = 1/y, backprop_y(x, y) = -x/y^2;
64 //
65 // The multiplication in the grad * backprop_x itself is also
66 // broadcasting following the same rule.
67 //
68 // TODO(zhifengc): Adds support for n-ary (n >= 2).
69 class BCast {
70  public:
71   // A vector of int64 representing the shape of tensor. The 0-th
72   // element is the outer-most dimension and the last element is the
73   // inner-most dimension. Note that we do not use TensorShape since
74   // it's more convenient to manipulate Vec directly for this module.
75   typedef gtl::InlinedVector<int64, 4> Vec;
76 
77   // Constructs all helper shapes, following the aforementioned rules.
78   //
79   // If "fewer_dims_optimization" is set to true (the default), the
80   // implementation tries to reduce intermediate dimensions needed to be more
81   // efficient.  This is transparent to the caller.
82   //
83   // If false, all intermediate shapes (except for grad_{x,y}_reduce_idx()) have
84   // the same number of dimensions as the larger of the two inputs.
85   BCast(const Vec& x, const Vec& y, const bool fewer_dims_optimization = true);
~BCast()86   ~BCast() {}
87 
88   // Returns true iff two operands are compatible according to the
89   // broadcasting rule.
IsValid()90   bool IsValid() const { return valid_; }
91 
92   // If and only if IsValid(), the following fields can be used in
93   // implementing a broadcasted binary tensor operation according to
94   // the broadcasting rule.
x_reshape()95   const Vec& x_reshape() const { return x_reshape_; }
x_bcast()96   const Vec& x_bcast() const { return x_bcast_; }
y_reshape()97   const Vec& y_reshape() const { return y_reshape_; }
y_bcast()98   const Vec& y_bcast() const { return y_bcast_; }
result_shape()99   const Vec& result_shape() const { return result_; }
output_shape()100   const Vec& output_shape() const { return output_; }
grad_x_reduce_idx()101   const Vec& grad_x_reduce_idx() const { return grad_x_reduce_idx_; }
grad_y_reduce_idx()102   const Vec& grad_y_reduce_idx() const { return grad_y_reduce_idx_; }
103 
104   // Static helpers.
105   static Vec FromShape(const TensorShape& shape);
106   static TensorShape ToShape(const BCast::Vec& vec);
107 
108   template <typename IndexType, int NDIMS>
ToIndexArrayType(const BCast::Vec & vec)109   static Eigen::array<IndexType, NDIMS> ToIndexArrayType(
110       const BCast::Vec& vec) {
111     CHECK_EQ(vec.size(), NDIMS);
112     Eigen::array<IndexType, NDIMS> ret;
113     for (int i = 0; i < NDIMS; ++i) ret[i] = vec[i];
114     return ret;
115   }
116 
117   template <int NDIMS>
ToIndexArray(const BCast::Vec & vec)118   static Eigen::array<Eigen::DenseIndex, NDIMS> ToIndexArray(
119       const BCast::Vec& vec) {
120     return ToIndexArrayType<Eigen::DenseIndex, NDIMS>(vec);
121   }
122 
123  private:
124   bool valid_ = true;
125   Vec x_reshape_;
126   Vec x_bcast_;
127   Vec y_reshape_;
128   Vec y_bcast_;
129   Vec result_;
130   Vec output_;
131   Vec grad_x_reduce_idx_;
132   Vec grad_y_reduce_idx_;
133 
134   static void Reverse(Vec* shape);
135 
136   TF_DISALLOW_COPY_AND_ASSIGN(BCast);
137 };
138 
139 }  // end namespace tensorflow
140 
141 #endif  // TENSORFLOW_CORE_UTIL_BCAST_H_
142