/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <locale>
#include <string>
#include <vector>

#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace text {

namespace {

template <typename SPLITS_TYPE>
class StringNGramsOp : public tensorflow::OpKernel {
 public:
  explicit StringNGramsOp(tensorflow::OpKernelConstruction* context)
      : tensorflow::OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("separator", &separator_));
    OP_REQUIRES_OK(context, context->GetAttr("ngram_widths", &ngram_widths_));
    OP_REQUIRES_OK(context, context->GetAttr("left_pad", &left_pad_));
    OP_REQUIRES_OK(context, context->GetAttr("right_pad", &right_pad_));
    OP_REQUIRES_OK(context, context->GetAttr("pad_width", &pad_width_));
    OP_REQUIRES_OK(context, context->GetAttr("preserve_short_sequences",
                                             &preserve_short_));
  }

  int get_pad_width(const int ngram_width) const {
    // Ngrams can be padded with either a fixed pad width or a dynamic pad
    // width depending on the 'pad_width' arg, but in no case should the
    // padding ever be wider than 'ngram_width' - 1.
    return std::min(pad_width_ < 0 ? ngram_width - 1 : pad_width_,
                    ngram_width - 1);
  }

  StatusOr<int> get_num_ngrams(const int length, const int ngram_width) const {
    int64 limit = kint32max;
    int pad_width = get_pad_width(ngram_width);
    if (pad_width > limit / 2 - length) {
      return errors::InvalidArgument(
          "Pad width could lead to integer overflow, got pad_width = ",
          pad_width);
    }
    return std::max(0, ((length + 2 * pad_width) - ngram_width) + 1);
  }
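
  // Example of the two helpers above: with length = 3, ngram_width = 2, and
  // pad_width_ = -1 (maximal padding), get_pad_width(2) returns
  // min(2 - 1, 2 - 1) = 1, so get_num_ngrams(3, 2) returns
  // max(0, (3 + 2 * 1) - 2 + 1) = 4.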

  void Compute(tensorflow::OpKernelContext* context) override {
    for (int ngram_width : ngram_widths_) {
      OP_REQUIRES(
          context, ngram_width > 0,
          errors::InvalidArgument("ngram_widths must contain positive values"));
    }

    const tensorflow::Tensor* data;
    OP_REQUIRES_OK(context, context->input("data", &data));
    const auto& input_data = data->flat<tstring>().data();

    const tensorflow::Tensor* splits;
    OP_REQUIRES_OK(context, context->input("data_splits", &splits));
    const auto& splits_vec = splits->flat<SPLITS_TYPE>();

    // Validate that the splits are valid indices into data, but only if
    // splits are specified.
    const int input_data_size = data->flat<tstring>().size();
    const int splits_vec_size = splits_vec.size();
    if (splits_vec_size > 0) {
      int prev_split = splits_vec(0);
      OP_REQUIRES(context, prev_split == 0,
                  errors::InvalidArgument("First split value must be 0, got ",
                                          prev_split));
      for (int i = 1; i < splits_vec_size; ++i) {
        bool valid_splits = splits_vec(i) >= prev_split;
        valid_splits = valid_splits && (splits_vec(i) <= input_data_size);
        OP_REQUIRES(context, valid_splits,
                    errors::InvalidArgument(
                        "Invalid split value ", splits_vec(i), ", must be in [",
                        prev_split, ", ", input_data_size, "]"));
        prev_split = splits_vec(i);
      }
      OP_REQUIRES(context, prev_split == input_data_size,
                  errors::InvalidArgument(
                      "Last split value must be data size. Expected ",
                      input_data_size, ", got ", prev_split));
    }

    int num_batch_items = splits_vec.size() - 1;
    tensorflow::Tensor* ngrams_splits;
    OP_REQUIRES_OK(
        context, context->allocate_output(1, splits->shape(), &ngrams_splits));
    auto ngrams_splits_data = ngrams_splits->flat<SPLITS_TYPE>().data();

    // If there is no data or size, return an empty ragged tensor.
    if (data->flat<tstring>().size() == 0 || splits_vec.size() == 0) {
      tensorflow::Tensor* empty;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, data->shape(), &empty));
      for (int i = 0; i <= num_batch_items; ++i) {
        ngrams_splits_data[i] = 0;
      }
      return;
    }

    ngrams_splits_data[0] = 0;
    for (int i = 1; i <= num_batch_items; ++i) {
      int length = splits_vec(i) - splits_vec(i - 1);
      int num_ngrams = 0;
      for (int ngram_width : ngram_widths_) {
        auto ngrams_or = get_num_ngrams(length, ngram_width);
        OP_REQUIRES_OK(context, ngrams_or.status());
        num_ngrams += ngrams_or.ValueOrDie();
      }
      if (preserve_short_ && length > 0 && num_ngrams == 0) {
        num_ngrams = 1;
      }
      ngrams_splits_data[i] = ngrams_splits_data[i - 1] + num_ngrams;
    }

    tensorflow::Tensor* ngrams;
    OP_REQUIRES_OK(
        context,
        context->allocate_output(
            0, TensorShape({ngrams_splits_data[num_batch_items]}), &ngrams));
    auto ngrams_data = ngrams->flat<tstring>().data();
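
    // Example of the bookkeeping above: with two rows of lengths 2 and 3,
    // ngram_widths = {2}, and pad_width_ = 0, the rows yield 1 and 2 ngrams
    // respectively, so ngrams_splits_data is {0, 1, 3} and the ngrams tensor
    // allocated above holds 3 strings.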

    for (int i = 0; i < num_batch_items; ++i) {
      auto data_start = &input_data[splits_vec(i)];
      int output_start_idx = ngrams_splits_data[i];
      for (int ngram_width : ngram_widths_) {
        auto output_start = &ngrams_data[output_start_idx];
        int length = splits_vec(i + 1) - splits_vec(i);
        auto ngrams_or = get_num_ngrams(length, ngram_width);
        OP_REQUIRES_OK(context, ngrams_or.status());
        int num_ngrams = ngrams_or.ValueOrDie();
        CreateNgrams(data_start, output_start, num_ngrams, ngram_width);
        output_start_idx += num_ngrams;
      }
      // If we're preserving short sequences, check whether any ngrams were
      // generated by comparing the current output start index to the original
      // one (ngrams_splits_data). If no ngrams were generated, the two are
      // equal, since output_start_idx is incremented by num_ngrams every time
      // a set of ngrams is created.
      if (preserve_short_ && output_start_idx == ngrams_splits_data[i]) {
        int data_length = splits_vec(i + 1) - splits_vec(i);
        // One legitimate reason to not have any ngrams when preserve_short_
        // is true is if the sequence itself is empty. In that case, move on.
        if (data_length == 0) {
          continue;
        }
        // We don't have to worry about dynamic padding sizes here: if padding
        // was dynamic, every sequence would have had sufficient padding to
        // generate at least one ngram.

        // If we reach this point, pad_width_ must be >= 0: a pad_width_ of -1
        // means max(ngram_widths) - 1, which cannot be computed here since
        // ngram_width is not known.
        OP_REQUIRES(
            context, pad_width_ >= 0,
            errors::InvalidArgument("Pad width should be >= 0 when "
                                    "preserve_short_sequences is True and "
                                    "ngram_widths are not provided, got ",
                                    pad_width_));
        int ngram_width = data_length + 2 * pad_width_;
        auto output_start = &ngrams_data[output_start_idx];
        int num_ngrams = 1;
        CreateNgrams(data_start, output_start, num_ngrams, ngram_width);
      }
    }
  }
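
  // Worked example for CreateNgrams below: with data = {"a", "b", "c"},
  // num_ngrams = 4, ngram_width = 2, separator_ = " ", left_pad_ = "LP",
  // right_pad_ = "RP", and pad_width_ = 1, the four output ngrams are
  // "LP a", "a b", "b c", and "c RP".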
  void CreateNgrams(const tstring* data, tstring* output, int num_ngrams,
                    int ngram_width) const {
    for (int ngram_index = 0; ngram_index < num_ngrams; ++ngram_index) {
      int pad_width = get_pad_width(ngram_width);
      int left_padding = std::max(0, pad_width - ngram_index);
      int right_padding =
          std::max(0, pad_width - (num_ngrams - (ngram_index + 1)));
      int num_tokens = ngram_width - (left_padding + right_padding);
      int data_start_index = left_padding > 0 ? 0 : ngram_index - pad_width;

      // Calculate the total expected size of the ngram so we can reserve the
      // correct amount of space in the string.
      int ngram_size = 0;
      // Size of the left padding.
      ngram_size += left_padding * left_pad_.length();
      // Size of the tokens.
      for (int n = 0; n < num_tokens; ++n) {
        ngram_size += data[data_start_index + n].length();
      }
      // Size of the right padding.
      ngram_size += right_padding * right_pad_.length();
      // Size of the separators.
      int num_separators = left_padding + right_padding + num_tokens - 1;
      ngram_size += num_separators * separator_.length();

      // Build the ngram.
      tstring* ngram = &output[ngram_index];
      ngram->reserve(ngram_size);
      for (int n = 0; n < left_padding; ++n) {
        ngram->append(left_pad_);
        ngram->append(separator_);
      }
      // Only output the first num_tokens - 1 pairs of data and separator.
      for (int n = 0; n < num_tokens - 1; ++n) {
        ngram->append(data[data_start_index + n]);
        ngram->append(separator_);
      }
      // Handle the cases of no tokens or no right padding, as these can
      // result in consecutive separators.
      if (num_tokens > 0) {
        // If we have tokens, output the last one, then pair each separator
        // with the right padding that follows, so the ngram ends with either
        // a token or the right pad.
        ngram->append(data[data_start_index + num_tokens - 1]);
        for (int n = 0; n < right_padding; ++n) {
          ngram->append(separator_);
          ngram->append(right_pad_);
        }
      } else {
        // If we don't have tokens, the last item appended to the ngram was
        // the separator from the left-padding loop above. Hence, append
        // right pad and separator pairs, making sure to finish with a pad,
        // not a separator.
        for (int n = 0; n < right_padding - 1; ++n) {
          ngram->append(right_pad_);
          ngram->append(separator_);
        }
        ngram->append(right_pad_);
      }

      // In debug mode only: validate that we've reserved enough space for
      // the ngram.
      DCHECK_EQ(ngram_size, ngram->size());
    }
  }

  string separator_;
  string left_pad_;
  string right_pad_;
  bool use_pad_;
  bool extend_pad_;
  bool preserve_short_;

  std::vector<int> ngram_widths_;
  int pad_width_;
};

}  // namespace

REGISTER_KERNEL_BUILDER(Name("StringNGrams")
                            .Device(tensorflow::DEVICE_CPU)
                            .TypeConstraint<int32>("Tsplits"),
                        StringNGramsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("StringNGrams")
                            .Device(tensorflow::DEVICE_CPU)
                            .TypeConstraint<int64_t>("Tsplits"),
                        StringNGramsOp<int64_t>);

}  // namespace text
}  // namespace tensorflow
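
// End-to-end example of the kernel registered above, exercising
// preserve_short_sequences: for data = {"a"}, data_splits = {0, 1},
// ngram_widths = {3}, pad_width = 0, and preserve_short_sequences = true,
// the row is too short to produce a trigram, so the op emits the whole
// sequence as a single ngram: ngrams = {"a"}, ngrams_splits = {0, 1}.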