• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/reduction_ops_common.h"
17 
18 #include "tensorflow/core/lib/strings/str_util.h"
19 
20 namespace tensorflow {
21 
out_reshape() const22 TensorShape ReductionHelper::out_reshape() const {
23   TensorShape shape;
24   for (auto size : out_reshape_) shape.AddDim(size);
25   return shape;
26 }
27 
28 // The final output shape must be allocated with this shape.
out_shape() const29 TensorShape ReductionHelper::out_shape() const {
30   TensorShape shape;
31   for (auto size : out_shape_) shape.AddDim(size);
32   return shape;
33 }
34 
shuffled_shape()35 TensorShape ReductionHelper::shuffled_shape() {
36   const int dims = data_reshape_.size();
37   TensorShape shape;
38   for (int i = reduce_first_axis_; i < dims; i += 2) {
39     shape.AddDim(data_reshape_[i]);
40   }
41   for (int i = !reduce_first_axis_; i < dims; i += 2) {
42     shape.AddDim(data_reshape_[i]);
43   }
44   return shape;
45 }
46 
permutation()47 gtl::InlinedVector<int32, 8> ReductionHelper::permutation() {
48   const int dims = data_reshape_.size();
49   const int unreduced_dims = (dims + !reduce_first_axis_) / 2;
50   gtl::InlinedVector<int32, 8> perm(dims);
51   for (int i = 0; i < unreduced_dims; i++) {
52     perm[i] = 2 * i + reduce_first_axis_;
53   }
54   for (int i = unreduced_dims; i < dims; i++) {
55     perm[i] = 2 * (i - unreduced_dims) + !reduce_first_axis_;
56   }
57   return perm;
58 }
59 
60 template <typename Tperm>
SimplifyHelper(const Tensor & data,const Tensor & axis,gtl::InlinedVector<bool,4> & bitmap)61 Status SimplifyHelper(const Tensor& data, const Tensor& axis,
62                       gtl::InlinedVector<bool, 4>& bitmap) {
63   auto axis_vec = axis.flat<Tperm>();
64   for (int64 i = 0; i < axis.NumElements(); ++i) {
65     Tperm index = axis_vec(i);
66     if (index < -data.dims() || index >= data.dims()) {
67       return errors::InvalidArgument("Invalid reduction dimension (", index,
68                                      " for input with ", data.dims(),
69                                      " dimension(s)");
70     }
71     index = (index + data.dims()) % data.dims();
72     bitmap[index] = true;
73   }
74   return Status::OK();
75 }
76 
Simplify(const Tensor & data,const Tensor & axis,const bool keep_dims)77 Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis,
78                                  const bool keep_dims) {
79   // bitmap[i] indicates whether to reduce data along i-th axis.
80   gtl::InlinedVector<bool, 4> bitmap(data.dims(), false);
81   if (axis.dtype() == DT_INT32) {
82     TF_RETURN_IF_ERROR(SimplifyHelper<int32>(data, axis, bitmap));
83   } else {
84     TF_RETURN_IF_ERROR(SimplifyHelper<int64>(data, axis, bitmap));
85   }
86   // Output tensor's dim sizes.
87   out_shape_.clear();
88   for (int i = 0; i < data.dims(); ++i) {
89     if (!bitmap[i]) {
90       // If we are not reducing along dimension i.
91       out_shape_.push_back(data.dim_size(i));
92     } else if (keep_dims) {
93       // We are reducing along dimension i, but we want to keep the
94       // same number of dimensions, so we set the dimension of i to
95       // '1'.
96       out_shape_.push_back(1);
97     }
98   }
99 
100   // Depending on bitmap[i] and bitmap[i-1], we can collapse axis of
101   // the input data before doing the reduction on the resulting
102   // tensor.  The shape of the reduction is a reshape of the final
103   // output.
104 
105   // We'll skip the leading 1s.
106   int dim_index = 0;
107   for (; dim_index < data.dims(); ++dim_index) {
108     if (data.dim_size(dim_index) != 1) break;
109   }
110   if (dim_index >= data.dims()) {
111     // Special case. The input is essentially a scalar.
112     reduce_first_axis_ = true;
113   } else {
114     // Starting from the (dim_index)-th dimension, dimensions
115     // alternates between runs that need to be reduced and runs that
116     // don't.
117     //
118     // NOTE: If a dimension has size 1, we group it as the current
119     // run so that we can minimize the number of runs.
120     //
121     // E.g., when we want to reduce a tensor of shape [2, 1, 3, 1,
122     // 5] by axes = [1, 4], we should treat the tensor as a [6, 5]
123     // and reduce by axes = [1] (i.e., the output is shape [6]).
124     reduce_first_axis_ = bitmap[dim_index];
125     data_reshape_.push_back(data.dim_size(dim_index));
126     ++dim_index;
127     for (; dim_index < data.dims(); ++dim_index) {
128       const auto size = data.dim_size(dim_index);
129       if (size == 1) {
130         bitmap[dim_index] = bitmap[dim_index - 1];
131       }
132       if (bitmap[dim_index - 1] != bitmap[dim_index]) {
133         // Starts a new run of reduce or !reduce.
134         data_reshape_.push_back(size);
135       } else {
136         // Continue a run of reduce or !reduce.
137         data_reshape_.back() *= size;
138       }
139     }
140     // If reduce_first_axis_ is true (input's dimension 0, 2, 4, etc
141     // are reduced), data_reshape_[1, 3, 5, ...]  is out_reshape_,
142     // otherwise, data_reshape_[0, 2, 4, ...] is.
143     for (size_t i = reduce_first_axis_ ? 1 : 0; i < data_reshape_.size();
144          i += 2) {
145       out_reshape_.push_back(data_reshape_[i]);
146     }
147   }
148 
149   VLOG(1) << "data reshape: " << str_util::Join(data_reshape_, ",");
150   VLOG(1) << "out  reshape: " << str_util::Join(out_reshape_, ",");
151   VLOG(1) << "out    shape: " << str_util::Join(out_shape_, ",");
152   return Status::OK();
153 }
154 
155 }  // namespace tensorflow
156