• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/string_ops.cc.
17 
18 #include <string>
19 
20 #include "tensorflow/core/framework/kernel_def_builder.h"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/framework/tensor_shape.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/lib/core/stringpiece.h"
27 #include "tensorflow/core/lib/gtl/inlined_vector.h"
28 #include "tensorflow/core/lib/strings/str_util.h"
29 
30 namespace tensorflow {
31 
32 namespace {
33 
GetStrides(const TensorShape & shape)34 const gtl::InlinedVector<int64, 8> GetStrides(const TensorShape& shape) {
35   gtl::InlinedVector<int64, 8> result(shape.dims());
36   int64 product = 1;
37   for (int32 i = shape.dims() - 1; i >= 0; --i) {
38     result[i] = product;
39     product *= shape.dim_size(i);
40   }
41   return result;
42 }
43 
44 // Given a linear index to a subset of dimensions, full shape,
45 // precomputed list of running products of the full shape, and list of
46 // dimensions in the subset, outputs the linear index to the full shape with
47 // nonspecified dimensions set to 0.  Dimensions must be ordered from outer-most
48 // to inner-most with respect to the subset linear index.
LinearSubIndexToFullIndex(int64 output_index,const gtl::InlinedVector<int32,8> & dim_list,const TensorShape & input_shape,const gtl::InlinedVector<int64,8> & strides)49 inline int64 LinearSubIndexToFullIndex(
50     int64 output_index, const gtl::InlinedVector<int32, 8>& dim_list,
51     const TensorShape& input_shape,
52     const gtl::InlinedVector<int64, 8>& strides) {
53   int64 result = 0;
54   int64 quotient = output_index;
55   for (int32 i = dim_list.size() - 1; i >= 0; --i) {
56     int32 dim = dim_list[i];
57     int64 dim_value = quotient % input_shape.dim_size(dim);
58     quotient = quotient / input_shape.dim_size(dim);
59     result += strides[dim] * dim_value;
60   }
61   return result;
62 }
63 
64 // Computes the number of input elements reduced per output element.
GetReductionIterSize(const gtl::InlinedVector<int32,8> & reduced_indices,const TensorShape & input_shape)65 int64 GetReductionIterSize(const gtl::InlinedVector<int32, 8>& reduced_indices,
66                            const TensorShape& input_shape) {
67   int64 result = 1;
68   for (int32 reduce_dim : reduced_indices) {
69     result *= input_shape.dim_size(reduce_dim);
70   }
71   return result;
72 }
73 
74 // Computes a list of all true reduced indices, accounting for negative
75 // indices.
GetReducedIndices(const Tensor & reduction_indices,int32 input_dims)76 gtl::InlinedVector<int32, 8> GetReducedIndices(const Tensor& reduction_indices,
77                                                int32 input_dims) {
78   const auto reduction_indices_flat = reduction_indices.flat<int32>();
79   const int32 reduction_dims = reduction_indices_flat.size();
80 
81   gtl::InlinedVector<int32, 8> reduced_indices(reduction_dims);
82   for (int32 i = 0; i < reduction_dims; ++i) {
83     reduced_indices[i] = reduction_indices_flat(reduction_dims - i - 1);
84     reduced_indices[i] += reduced_indices[i] < 0 ? input_dims : 0;
85   }
86 
87   return reduced_indices;
88 }
89 
90 // Appends all unreduced dimensions to the given vector.
MakeUnreducedIndices(gtl::InlinedVector<bool,8> index_is_reduced,int32 input_dims,gtl::InlinedVector<int32,8> * unreduced_indices)91 void MakeUnreducedIndices(gtl::InlinedVector<bool, 8> index_is_reduced,
92                           int32 input_dims,
93                           gtl::InlinedVector<int32, 8>* unreduced_indices) {
94   for (int32 index = 0; index < input_dims; ++index) {
95     if (!index_is_reduced[index]) unreduced_indices->push_back(index);
96   }
97 }
98 
GetOutputShape(gtl::InlinedVector<bool,8> index_is_reduced,const TensorShape & input_shape,bool keep_dims)99 TensorShape GetOutputShape(gtl::InlinedVector<bool, 8> index_is_reduced,
100                            const TensorShape& input_shape, bool keep_dims) {
101   TensorShape output_shape;
102   for (size_t index = 0; index < index_is_reduced.size(); ++index) {
103     if (index_is_reduced[index]) {
104       if (keep_dims) output_shape.AddDim(1);
105     } else {
106       output_shape.AddDim(input_shape.dim_size(index));
107     }
108   }
109   return output_shape;
110 }
111 
112 }  // namespace
113 
114 class ReduceJoinOp : public OpKernel {
115  public:
116   using OpKernel::OpKernel;
117 
ReduceJoinOp(OpKernelConstruction * ctx)118   explicit ReduceJoinOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
119     OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
120     OP_REQUIRES_OK(ctx, ctx->GetAttr("separator", &separator_));
121   }
122 
Compute(OpKernelContext * context)123   void Compute(OpKernelContext* context) override {
124     const Tensor& input = context->input(0);
125     const auto input_flat = input.flat<string>();
126     const TensorShape& input_shape = input.shape();
127     const int32 input_dims = input_shape.dims();
128 
129     const Tensor& reduction_indices = context->input(1);
130     const auto reduction_indices_flat = reduction_indices.flat<int32>();
131     const int32 reduction_dims = reduction_indices_flat.size();
132 
133     gtl::InlinedVector<bool, 8> index_is_reduced(input_dims, false);
134     for (int32 i = 0; i < reduction_dims; i++) {
135       int32 reduce_index = reduction_indices_flat(i);
136       const int32 true_reduce_index =
137           reduce_index < 0 ? reduce_index + input_dims : reduce_index;
138       OP_REQUIRES(
139           context, reduce_index >= -input_dims && reduce_index < input_dims,
140           errors::OutOfRange("Invalid reduction dimension ", reduce_index,
141                              " for input with ", input_dims, " dimension(s)"));
142       OP_REQUIRES(context, !index_is_reduced[true_reduce_index],
143                   errors::InvalidArgument("Duplicate reduction dimension ",
144                                           reduce_index));
145       index_is_reduced[true_reduce_index] = true;
146     }
147 
148     gtl::InlinedVector<int32, 8> reduced_indices =
149         GetReducedIndices(reduction_indices, input_dims);
150     gtl::InlinedVector<int32, 8> unreduced_indices;
151     MakeUnreducedIndices(index_is_reduced, input_dims, &unreduced_indices);
152     const auto strides = GetStrides(input_shape);
153 
154     Tensor* output_tensor = nullptr;
155     TensorShape output_shape =
156         GetOutputShape(index_is_reduced, input_shape, keep_dims_);
157     OP_REQUIRES_OK(context, context->allocate_output("output", output_shape,
158                                                      &output_tensor));
159     auto output_flat = output_tensor->flat<string>();
160 
161     const int64 reduction_iter_size =
162         GetReductionIterSize(reduced_indices, input_shape);
163     gtl::InlinedVector<StringPiece, 8> curr_strings(reduction_iter_size);
164     for (int64 output_index = 0; output_index < output_shape.num_elements();
165          ++output_index) {
166       int64 output_full_index = LinearSubIndexToFullIndex(
167           output_index, unreduced_indices, input_shape, strides);
168       for (int64 reduction_index = 0; reduction_index < reduction_iter_size;
169            ++reduction_index) {
170         int64 reduction_full_index = LinearSubIndexToFullIndex(
171             reduction_index, reduced_indices, input_shape, strides);
172         curr_strings[reduction_index] =
173             input_flat(output_full_index + reduction_full_index);
174       }
175       output_flat(output_index) =
176           str_util::Join(curr_strings, separator_.c_str());
177     }
178   }
179 
180  private:
181   bool keep_dims_;
182   string separator_;
183 };
184 
185 REGISTER_KERNEL_BUILDER(Name("ReduceJoin").Device(DEVICE_CPU), ReduceJoinOp);
186 
187 }  // namespace tensorflow
188