/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <locale>
#include <string>
#include <vector>

#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/statusor.h"  // For StatusOr.
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace text {

namespace {
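// Kernel that computes ngrams over a ragged batch of strings: "data" holds
// the flattened string values and "data_splits" the row boundaries. For each
// row, every width in "ngram_widths" produces a set of (optionally padded)
// ngrams whose tokens are joined with "separator".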
template <typename SPLITS_TYPE>
class StringNGramsOp : public tensorflow::OpKernel {
 public:
  explicit StringNGramsOp(tensorflow::OpKernelConstruction* context)
      : tensorflow::OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("separator", &separator_));
    OP_REQUIRES_OK(context, context->GetAttr("ngram_widths", &ngram_widths_));
    OP_REQUIRES_OK(context, context->GetAttr("left_pad", &left_pad_));
    OP_REQUIRES_OK(context, context->GetAttr("right_pad", &right_pad_));
    OP_REQUIRES_OK(context, context->GetAttr("pad_width", &pad_width_));
    OP_REQUIRES_OK(context, context->GetAttr("preserve_short_sequences",
                                             &preserve_short_));
  }

  int get_pad_width(const int ngram_width) const {
    // Ngrams can be padded with either a fixed pad width or a dynamic pad
    // width depending on the 'pad_width' attr, but in no case should the
    // padding ever be wider than 'ngram_width' - 1.
    return std::min(pad_width_ < 0 ? ngram_width - 1 : pad_width_,
                    ngram_width - 1);
  }
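  // For example (illustrative values, not from the original source): with
  // ngram_width == 3, a dynamic pad width (pad_width_ == -1) yields 2, and a
  // fixed pad_width_ of 5 is clamped down to 2.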

  StatusOr<int> get_num_ngrams(const int length, const int ngram_width) const {
    int64 limit = kint32max;
    int pad_width = get_pad_width(ngram_width);
    // Conservatively reject pad widths for which the padded length,
    // length + 2 * pad_width, could overflow an int32.
    if (pad_width > limit / 2 - length) {
      return errors::InvalidArgument(
          "Pad width could lead to integer overflow, got pad_width = ",
          pad_width);
    }
    return std::max(0, ((length + 2 * pad_width) - ngram_width) + 1);
  }
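  // Worked example (illustrative, not from the original source): length == 3
  // tokens and ngram_width == 2 with dynamic padding (pad_width == 1) give
  // (3 + 2 * 1) - 2 + 1 == 4 ngrams.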

  void Compute(tensorflow::OpKernelContext* context) override {
    for (int ngram_width : ngram_widths_) {
      OP_REQUIRES(
          context, ngram_width > 0,
          errors::InvalidArgument("ngram_widths must contain positive values"));
    }

    const tensorflow::Tensor* data;
    OP_REQUIRES_OK(context, context->input("data", &data));
    const auto& input_data = data->flat<tstring>().data();

    const tensorflow::Tensor* splits;
    OP_REQUIRES_OK(context, context->input("data_splits", &splits));
    const auto& splits_vec = splits->flat<SPLITS_TYPE>();

    // Validate that the splits are valid indices into data, but only if any
    // splits were specified.
    const int input_data_size = data->flat<tstring>().size();
    const int splits_vec_size = splits_vec.size();
    if (splits_vec_size > 0) {
      int prev_split = splits_vec(0);
      OP_REQUIRES(context, prev_split == 0,
                  errors::InvalidArgument("First split value must be 0, got ",
                                          prev_split));
      for (int i = 1; i < splits_vec_size; ++i) {
        bool valid_splits = splits_vec(i) >= prev_split;
        valid_splits = valid_splits && (splits_vec(i) <= input_data_size);
        OP_REQUIRES(context, valid_splits,
                    errors::InvalidArgument(
                        "Invalid split value ", splits_vec(i), ", must be in [",
                        prev_split, ", ", input_data_size, "]"));
        prev_split = splits_vec(i);
      }
      OP_REQUIRES(context, prev_split == input_data_size,
                  errors::InvalidArgument(
                      "Last split value must be data size. Expected ",
                      input_data_size, ", got ", prev_split));
    }
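    // Illustrative example (not from the original source): for
    // data = ["a", "b", "c", "d"], a valid splits vector is [0, 2, 4],
    // describing the two rows ["a", "b"] and ["c", "d"].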

    int num_batch_items = splits_vec.size() - 1;
    tensorflow::Tensor* ngrams_splits;
    OP_REQUIRES_OK(
        context, context->allocate_output(1, splits->shape(), &ngrams_splits));
    auto ngrams_splits_data = ngrams_splits->flat<SPLITS_TYPE>().data();

    // If there is no data or no splits, return an empty ragged tensor.
    if (data->flat<tstring>().size() == 0 || splits_vec.size() == 0) {
      tensorflow::Tensor* empty;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, data->shape(), &empty));
      for (int i = 0; i <= num_batch_items; ++i) {
        ngrams_splits_data[i] = 0;
      }
      return;
    }

    // First pass: count the ngrams generated for each row to build the
    // output row-splits tensor.
    ngrams_splits_data[0] = 0;
    for (int i = 1; i <= num_batch_items; ++i) {
      int length = splits_vec(i) - splits_vec(i - 1);
      int num_ngrams = 0;
      for (int ngram_width : ngram_widths_) {
        auto ngrams_or = get_num_ngrams(length, ngram_width);
        OP_REQUIRES_OK(context, ngrams_or.status());
        num_ngrams += ngrams_or.ValueOrDie();
      }
      if (preserve_short_ && length > 0 && num_ngrams == 0) {
        num_ngrams = 1;
      }
      ngrams_splits_data[i] = ngrams_splits_data[i - 1] + num_ngrams;
    }
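    // For instance (illustrative values, not from the original source): with
    // ngram_widths == {2, 3}, dynamic padding, and a row of length 3, the row
    // contributes 4 + 5 == 9 ngrams to the running splits total.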

    tensorflow::Tensor* ngrams;
    OP_REQUIRES_OK(
        context,
        context->allocate_output(
            0, TensorShape({ngrams_splits_data[num_batch_items]}), &ngrams));
    auto ngrams_data = ngrams->flat<tstring>().data();

    // Second pass: materialize the ngrams for each row and width.
    for (int i = 0; i < num_batch_items; ++i) {
      auto data_start = &input_data[splits_vec(i)];
      int output_start_idx = ngrams_splits_data[i];
      for (int ngram_width : ngram_widths_) {
        auto output_start = &ngrams_data[output_start_idx];
        int length = splits_vec(i + 1) - splits_vec(i);
        auto ngrams_or = get_num_ngrams(length, ngram_width);
        OP_REQUIRES_OK(context, ngrams_or.status());
        int num_ngrams = ngrams_or.ValueOrDie();
        CreateNgrams(data_start, output_start, num_ngrams, ngram_width);
        output_start_idx += num_ngrams;
      }
      // If we're preserving short sequences, check whether any ngrams were
      // generated by comparing the current output start index to the original
      // one (ngrams_splits_data[i]). The two are equal exactly when no ngrams
      // were generated, since output_start_idx is incremented by num_ngrams
      // every time a set of ngrams is created.
      if (preserve_short_ && output_start_idx == ngrams_splits_data[i]) {
        int data_length = splits_vec(i + 1) - splits_vec(i);
        // One legitimate reason to not have any ngrams when preserve_short_
        // is true is if the sequence itself is empty. In that case, move on.
        if (data_length == 0) {
          continue;
        }
        // We don't have to worry about dynamic padding sizes here: if padding
        // was dynamic, every non-empty sequence would have had sufficient
        // padding to generate at least one ngram.

        // If we reach this point, a fixed pad width is required: with a
        // dynamic pad width (pad_width_ == -1), max(ngram_widths) - 1 cannot
        // be used as the pad width, since ngram_widths may not be provided.
        OP_REQUIRES(
            context, pad_width_ >= 0,
            errors::InvalidArgument("Pad width should be >= 0 when "
                                    "preserve_short_sequences is True and "
                                    "ngram_widths are not provided, got ",
                                    pad_width_));
        int ngram_width = data_length + 2 * pad_width_;
        auto output_start = &ngrams_data[output_start_idx];
        int num_ngrams = 1;
        CreateNgrams(data_start, output_start, num_ngrams, ngram_width);
      }
    }
  }

  // Writes `num_ngrams` ngrams of width `ngram_width`, built from the tokens
  // starting at `data`, into `output`, applying padding and joining tokens
  // with the configured separator.
  void CreateNgrams(const tstring* data, tstring* output, int num_ngrams,
                    int ngram_width) const {
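    // Illustrative example (assumed values, not from the original source):
    // with separator " ", left_pad "LP", right_pad "RP", dynamic padding
    // (pad_width_ == -1), tokens {"a", "b"}, and ngram_width == 3, the four
    // generated ngrams (num_ngrams == 4) are:
    //   "LP LP a", "LP a b", "a b RP", "b RP RP"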
    for (int ngram_index = 0; ngram_index < num_ngrams; ++ngram_index) {
      int pad_width = get_pad_width(ngram_width);
      int left_padding = std::max(0, pad_width - ngram_index);
      int right_padding =
          std::max(0, pad_width - (num_ngrams - (ngram_index + 1)));
      int num_tokens = ngram_width - (left_padding + right_padding);
      int data_start_index = left_padding > 0 ? 0 : ngram_index - pad_width;

      // Calculate the total expected size of the ngram so we can reserve the
      // correct amount of space in the string.
      int ngram_size = 0;
      // Size of the left padding.
      ngram_size += left_padding * left_pad_.length();
      // Size of the tokens.
      for (int n = 0; n < num_tokens; ++n) {
        ngram_size += data[data_start_index + n].length();
      }
      // Size of the right padding.
      ngram_size += right_padding * right_pad_.length();
      // Size of the separators.
      int num_separators = left_padding + right_padding + num_tokens - 1;
      ngram_size += num_separators * separator_.length();

      // Build the ngram.
      tstring* ngram = &output[ngram_index];
      ngram->reserve(ngram_size);
      for (int n = 0; n < left_padding; ++n) {
        ngram->append(left_pad_);
        ngram->append(separator_);
      }
      // Only output the first num_tokens - 1 pairs of token and separator.
      for (int n = 0; n < num_tokens - 1; ++n) {
        ngram->append(data[data_start_index + n]);
        ngram->append(separator_);
      }
      // Handle the cases of no tokens or no right padding, as these can
      // result in consecutive separators.
      if (num_tokens > 0) {
        // If we have tokens, output the last one, then pair each separator
        // with the right padding that follows, to ensure the ngram ends with
        // either a token or the right pad.
        ngram->append(data[data_start_index + num_tokens - 1]);
        for (int n = 0; n < right_padding; ++n) {
          ngram->append(separator_);
          ngram->append(right_pad_);
        }
      } else {
        // If we don't have tokens, the last item inserted into the ngram was
        // the separator from the left padding loop above. Hence, output right
        // pad and separator pairs, making sure to finish with a padding, not
        // a separator.
        for (int n = 0; n < right_padding - 1; ++n) {
          ngram->append(right_pad_);
          ngram->append(separator_);
        }
        ngram->append(right_pad_);
      }

      // In debug mode only: validate that we've reserved enough space for the
      // ngram.
      DCHECK_EQ(ngram_size, ngram->size());
    }
  }

  string separator_;
  string left_pad_;
  string right_pad_;
  bool use_pad_;
  bool extend_pad_;
  bool preserve_short_;

  std::vector<int> ngram_widths_;
  int pad_width_;
};

}  // namespace
REGISTER_KERNEL_BUILDER(Name("StringNGrams")
                            .Device(tensorflow::DEVICE_CPU)
                            .TypeConstraint<int32>("Tsplits"),
                        StringNGramsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("StringNGrams")
                            .Device(tensorflow::DEVICE_CPU)
                            .TypeConstraint<int64_t>("Tsplits"),
                        StringNGramsOp<int64_t>);

}  // namespace text
}  // namespace tensorflow