/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/string_ops.cc.

#include <string>
#include <utility>

#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/str_util.h"

namespace tensorflow {

namespace {

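// Maps each segment-id index i in [0, small_stride) to the offset of its
// block of `big_stride` contiguous elements in the flattened input.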
template <typename INDICES_TYPE>
gtl::InlinedVector<INDICES_TYPE, 8> GetFlattenedRelativeOffsets(
    INDICES_TYPE small_stride, INDICES_TYPE big_stride) {
  gtl::InlinedVector<INDICES_TYPE, 8> flattened_offsets(small_stride);
  for (auto i = 0; i < small_stride; i++) {
    flattened_offsets[i] = i * big_stride;
  }
  return flattened_offsets;
}

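// Factors the flattened input into two strides: `small_stride` is the number
// of segment ids (product of the segment_id dims), and `big_stride` is the
// number of input elements covered by each segment id (product of the
// remaining input dims). Returned as (big_stride, small_stride).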
template <typename INDICES_TYPE>
std::pair<INDICES_TYPE, INDICES_TYPE> GetStrides(
    const TensorShape& input_shape, const TensorShape& segment_id_shape) {
  int64_t small_stride = 1;
  int64_t big_stride = 1;
  for (auto i = 0; i < input_shape.dims(); i++) {
    if (i < segment_id_shape.dims()) {
      small_stride *= segment_id_shape.dim_size(i);
    } else {
      big_stride *= input_shape.dim_size(i);
    }
  }
  return std::make_pair(big_stride, small_stride);
}

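// The output shape is [num_segments] followed by the input dims not covered
// by segment_ids.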
TensorShape GetOutputShape(const TensorShape& input_shape,
                           const TensorShape& segment_id_shape,
                           const int64_t num_segments) {
  TensorShape output_shape;
  output_shape.AddDim(num_segments);
  for (int index = segment_id_shape.dims(); index < input_shape.dims();
       ++index) {
    output_shape.AddDim(input_shape.dim_size(index));
  }
  return output_shape;
}

}  // namespace

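// Kernel that joins the strings of `input` into `num_segments` groups given
// by `segment_ids`, inserting `separator` between elements of the same
// group. For example, with input = ["a", "b", "c", "d"],
// segment_ids = [1, 0, 1, 0], num_segments = 2, and separator = ":", the
// output is ["b:d", "a:c"].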
template <typename INDICES_TYPE, typename NUM_SEGMENTS_TYPE>
class UnsortedSegmentJoinOp : public OpKernel {
 public:
  using OpKernel::OpKernel;

  explicit UnsortedSegmentJoinOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("separator", &separator_));
  }

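  // Validates the input shapes, allocates the output, and joins the input
  // strings segment by segment.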
  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();
    const int32_t input_dims = input_shape.dims();

    const Tensor& segment_id = context->input(1);
    const TensorShape& segment_id_shape = segment_id.shape();
    const int32_t segment_dims = segment_id_shape.dims();

    const Tensor& num_segments_tensor = context->input(2);
    OP_REQUIRES(context, num_segments_tensor.NumElements() != 0,
                errors::InvalidArgument("Number of segments cannot be empty."));
    OP_REQUIRES(context,
                TensorShapeUtils::IsScalar(num_segments_tensor.shape()),
                errors::InvalidArgument("Number of segments must be a scalar"));
    auto num_segments = num_segments_tensor.scalar<NUM_SEGMENTS_TYPE>()();

    OP_REQUIRES(
        context, num_segments >= 0,
        errors::InvalidArgument(
            "Number of segments must be non-negative but got ", num_segments));
    OP_REQUIRES(context, segment_dims != 0,
                errors::InvalidArgument("Segment_id cannot have rank 0"));

    OP_REQUIRES(
        context, segment_dims <= input_dims,
        errors::OutOfRange("Invalid segment_id rank ", segment_dims,
                           " for input with ", input_dims, " dimension(s)"));
    for (auto i = 0; i < segment_dims; i++) {
      OP_REQUIRES(
          context, segment_id_shape.dim_size(i) == input_shape.dim_size(i),
          errors::InvalidArgument(
              "Segment dimension is ", segment_id_shape.dim_size(i),
              " while input dimension is ", input_shape.dim_size(i),
              " in rank ", i));
    }

    // Allocate the output tensor.
    Tensor* output_tensor = nullptr;
    TensorShape output_shape =
        GetOutputShape(input_shape, segment_id_shape, num_segments);
    OP_REQUIRES_OK(context, context->allocate_output("output", output_shape,
                                                     &output_tensor));

    // Prepare flat views of the tensors.
    auto output_flat = output_tensor->flat<tstring>();
    auto flat_segment_id = segment_id.flat<INDICES_TYPE>();
    auto flat_input = input.flat<tstring>();

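    // Reject any segment id outside [0, num_segments) before writing output.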
    for (int i = 0; i < flat_segment_id.size(); i++) {
      OP_REQUIRES(
          context,
          ((flat_segment_id(i) < num_segments) && (flat_segment_id(i) >= 0)),
          errors::InvalidArgument(
              "segment_ids are not allowed to exceed num_segments or"
              " to have negative values."));
    }

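    // Each segment id owns a contiguous block of `big_stride` input elements;
    // element `start_offset` of block i is appended (with the separator) to
    // position `start_offset` of the output row flat_segment_id(i).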
    int64_t big_stride;
    int64_t small_stride;
    std::tie(big_stride, small_stride) =
        GetStrides<INDICES_TYPE>(input_shape, segment_id_shape);
    auto relative_offset_set =
        GetFlattenedRelativeOffsets<INDICES_TYPE>(small_stride, big_stride);
    for (int64_t start_offset = 0; start_offset < big_stride; start_offset++) {
      for (size_t i = 0; i < relative_offset_set.size(); i++) {
        auto output_index = start_offset + flat_segment_id(i) * big_stride;
        auto offset = start_offset + relative_offset_set[i];
        if (output_flat(output_index).length() != 0)
          output_flat(output_index).append(separator_.c_str());
        output_flat(output_index).append(flat_input(offset));
      }
    }
  }

 private:
  string separator_;
};

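// Register CPU kernels for every supported combination of index and
// num_segments integer types.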
#define REGISTER_CPU_KERNEL(indices_type, num_segments_type)  \
  REGISTER_KERNEL_BUILDER(                                    \
      Name("UnsortedSegmentJoin")                             \
          .Device(DEVICE_CPU)                                 \
          .TypeConstraint<indices_type>("Tindices")           \
          .TypeConstraint<num_segments_type>("Tnumsegments"), \
      UnsortedSegmentJoinOp<indices_type, num_segments_type>);

REGISTER_CPU_KERNEL(int32, int32);
REGISTER_CPU_KERNEL(int32, int64_t);
REGISTER_CPU_KERNEL(int64_t, int32);
REGISTER_CPU_KERNEL(int64_t, int64_t);
#undef REGISTER_CPU_KERNEL

}  // namespace tensorflow