/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_BCAST_H_
#define TENSORFLOW_CORE_UTIL_BCAST_H_

#include <algorithm>

#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// BCast is a helper for broadcasting binary tensor operations.
// TensorFlow's broadcasting rule follows that of numpy (see
// http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
//
// The rule has the following properties:
//
//   1. Suffix matching: the rule starts with the right-most
//      dimension and works towards the left-most dimension. Since
//      TensorFlow is row-major, the right-most dimension (the last
//      element in the shape of a tensor) is the inner-most, a.k.a.
//      the fastest changing, dimension.
//
//   2. Two dimensions are compatible for broadcasting if both are the
//      same or either is 1.
//
// BCast takes the shapes of two tensors and computes a few vectors of
// int64 that are useful for the caller to reshape the tensors, apply
// the right broadcasts to them, compute the broadcasted operation,
// and possibly the gradients. In a nutshell, the caller is expected
// to compute the broadcasted operation as follows:
//
//   BCast b(x.shape(), y.shape());
//   output = x.reshape(b.x_reshape()).broadcast(b.x_bcast())
//            _op_
//            y.reshape(b.y_reshape()).broadcast(b.y_bcast())
//
// For the gradient computation,
//
//   grad_x = sum(grad * backprop_x(x, y), grad_x_reduce_idx)
//            .reshape(x.shape())
//   grad_y = sum(grad * backprop_y(x, y), grad_y_reduce_idx)
//            .reshape(y.shape())
//
// where backprop_x and backprop_y are functionals of the binary function
// "op", e.g.,
//   for +, backprop_x(x, y) = backprop_y(x, y) = 1;
//   for *, backprop_x(x, y) = y, backprop_y(x, y) = x;
//   for /, backprop_x(x, y) = 1/y, backprop_y(x, y) = -x/y^2;
//
// The multiplication in grad * backprop_x is itself broadcasting,
// following the same rule.
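//
// As a concrete illustration (a hypothetical walkthrough, not part of the
// API contract), broadcasting x with shape [2, 3, 5] against y with shape
// [5], with fewer_dims_optimization disabled, should yield:
//
//   BCast b({2, 3, 5}, {5}, /*fewer_dims_optimization=*/false);
//   b.IsValid()            // true
//   b.x_reshape()          // [2, 3, 5]    b.x_bcast()  // [1, 1, 1]
//   b.y_reshape()          // [1, 1, 5]    b.y_bcast()  // [2, 3, 1]
//   b.output_shape()       // [2, 3, 5]
//   b.grad_x_reduce_idx()  // []     (x already has the output shape)
//   b.grad_y_reduce_idx()  // [0, 1] (sum over the broadcast dimensions)
//
// With the default fewer_dims_optimization, adjacent dimensions that
// broadcast the same way may be coalesced, e.g. x_reshape() could become
// [6, 5] and y_reshape() [1, 5], describing the same computation with
// fewer dimensions.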
//
// TODO(zhifengc): Add support for n-ary (n >= 2).
class BCast {
 public:
  // A vector of int64 representing the shape of a tensor. The 0-th
  // element is the outer-most dimension and the last element is the
  // inner-most dimension. Note that we do not use TensorShape since
  // it's more convenient to manipulate Vec directly for this module.
  typedef gtl::InlinedVector<int64, 4> Vec;

  // Constructs all helper shapes, following the aforementioned rules.
  //
  // If "fewer_dims_optimization" is set to true (the default), the
  // implementation tries to reduce the number of intermediate dimensions
  // needed, for efficiency. This is transparent to the caller.
  //
  // If false, all intermediate shapes (except for grad_{x,y}_reduce_idx())
  // have the same number of dimensions as the larger of the two inputs.
  BCast(const Vec& x, const Vec& y, const bool fewer_dims_optimization = true);
  ~BCast() {}

  // Returns true iff the two operands are compatible according to the
  // broadcasting rule.
  bool IsValid() const { return valid_; }

  // If and only if IsValid(), the following fields can be used in
  // implementing a broadcasted binary tensor operation according to
  // the broadcasting rule.
  const Vec& x_reshape() const { return x_reshape_; }
  const Vec& x_bcast() const { return x_bcast_; }
  const Vec& y_reshape() const { return y_reshape_; }
  const Vec& y_bcast() const { return y_bcast_; }
  const Vec& result_shape() const { return result_; }
  const Vec& output_shape() const { return output_; }
  const Vec& grad_x_reduce_idx() const { return grad_x_reduce_idx_; }
  const Vec& grad_y_reduce_idx() const { return grad_y_reduce_idx_; }

  // Static helpers. FromShape/ToShape convert between TensorShape and Vec.
  static Vec FromShape(const TensorShape& shape);
  static TensorShape ToShape(const BCast::Vec& vec);

  // Copies "vec" into a fixed-size Eigen index array; "vec" must have
  // exactly NDIMS elements.
  template <typename IndexType, int NDIMS>
  static Eigen::array<IndexType, NDIMS> ToIndexArrayType(
      const BCast::Vec& vec) {
    CHECK_EQ(vec.size(), NDIMS);
    Eigen::array<IndexType, NDIMS> ret;
    for (int i = 0; i < NDIMS; ++i) ret[i] = vec[i];
    return ret;
  }

  template <int NDIMS>
  static Eigen::array<Eigen::DenseIndex, NDIMS> ToIndexArray(
      const BCast::Vec& vec) {
    return ToIndexArrayType<Eigen::DenseIndex, NDIMS>(vec);
  }

 private:
  bool valid_ = true;
  Vec x_reshape_;
  Vec x_bcast_;
  Vec y_reshape_;
  Vec y_bcast_;
  Vec result_;
  Vec output_;
  Vec grad_x_reduce_idx_;
  Vec grad_y_reduce_idx_;

  static void Reverse(Vec* shape);

  TF_DISALLOW_COPY_AND_ASSIGN(BCast);
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_UTIL_BCAST_H_