/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <locale>
#include <string>

#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace text {

namespace {
template <typename SPLITS_TYPE>
class StringNGramsOp : public tensorflow::OpKernel {
 public:
  explicit StringNGramsOp(tensorflow::OpKernelConstruction* context)
      : tensorflow::OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("separator", &separator_));
    OP_REQUIRES_OK(context, context->GetAttr("ngram_widths", &ngram_widths_));
    OP_REQUIRES_OK(context, context->GetAttr("left_pad", &left_pad_));
    OP_REQUIRES_OK(context, context->GetAttr("right_pad", &right_pad_));
    OP_REQUIRES_OK(context, context->GetAttr("pad_width", &pad_width_));
    OP_REQUIRES_OK(context, context->GetAttr("preserve_short_sequences",
                                             &preserve_short_));
  }

  int get_pad_width(const int ngram_width) const {
    // Ngrams can be padded with either a fixed pad width or a dynamic pad
    // width depending on the 'pad_width' arg, but in no case should the
    // padding ever be wider than 'ngram_width' - 1.
    return std::min(pad_width_ < 0 ? ngram_width - 1 : pad_width_,
                    ngram_width - 1);
  }

  int get_num_ngrams(const int length, const int ngram_width) const {
    int pad_width = get_pad_width(ngram_width);
    return std::max(0, ((length + 2 * pad_width) - ngram_width) + 1);
  }

  void Compute(tensorflow::OpKernelContext* context) override {
    for (int ngram_width : ngram_widths_) {
      OP_REQUIRES(
          context, ngram_width > 0,
          errors::InvalidArgument("ngram_widths must contain positive values"));
    }

    const tensorflow::Tensor* data;
    OP_REQUIRES_OK(context, context->input("data", &data));
    const auto& input_data = data->flat<tstring>().data();

    const tensorflow::Tensor* splits;
    OP_REQUIRES_OK(context, context->input("data_splits", &splits));
    const auto& splits_vec = splits->flat<SPLITS_TYPE>();

    // Validate that the splits are valid indices into data, only if there
    // are splits specified.
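    // Illustrative example (values chosen for illustration): with
    // input_data_size == 5, splits [0, 2, 5] describe two batch items of
    // lengths 2 and 3 and pass every check below; [1, 2, 5] (first split
    // nonzero), [0, 3, 2] (decreasing), and [0, 2, 4] (last split != data
    // size) are each rejected with an InvalidArgument error.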
    const int input_data_size = data->flat<tstring>().size();
    const int splits_vec_size = splits_vec.size();
    if (splits_vec_size > 0) {
      int prev_split = splits_vec(0);
      OP_REQUIRES(context, prev_split == 0,
                  errors::InvalidArgument("First split value must be 0, got ",
                                          prev_split));
      for (int i = 1; i < splits_vec_size; ++i) {
        bool valid_splits = splits_vec(i) >= prev_split;
        valid_splits = valid_splits && (splits_vec(i) <= input_data_size);
        OP_REQUIRES(context, valid_splits,
                    errors::InvalidArgument(
                        "Invalid split value ", splits_vec(i), ", must be in [",
                        prev_split, ", ", input_data_size, "]"));
        prev_split = splits_vec(i);
      }
      OP_REQUIRES(context, prev_split == input_data_size,
                  errors::InvalidArgument(
                      "Last split value must be data size. Expected ",
                      input_data_size, ", got ", prev_split));
    }

    int num_batch_items = splits_vec.size() - 1;
    tensorflow::Tensor* ngrams_splits;
    OP_REQUIRES_OK(
        context, context->allocate_output(1, splits->shape(), &ngrams_splits));
    auto ngrams_splits_data = ngrams_splits->flat<SPLITS_TYPE>().data();

    // If there is no data or size, return an empty RT.
    if (data->flat<tstring>().size() == 0 || splits_vec.size() == 0) {
      tensorflow::Tensor* empty;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, data->shape(), &empty));
      for (int i = 0; i <= num_batch_items; ++i) {
        ngrams_splits_data[i] = 0;
      }
      return;
    }

    ngrams_splits_data[0] = 0;
    for (int i = 1; i <= num_batch_items; ++i) {
      int length = splits_vec(i) - splits_vec(i - 1);
      int num_ngrams = 0;
      for (int ngram_width : ngram_widths_)
        num_ngrams += get_num_ngrams(length, ngram_width);
      if (preserve_short_ && length > 0 && num_ngrams == 0) {
        num_ngrams = 1;
      }
      ngrams_splits_data[i] = ngrams_splits_data[i - 1] + num_ngrams;
    }

    tensorflow::Tensor* ngrams;
    OP_REQUIRES_OK(
        context,
        context->allocate_output(
            0, TensorShape({ngrams_splits_data[num_batch_items]}), &ngrams));
    auto ngrams_data = ngrams->flat<tstring>().data();

    for (int i = 0; i < num_batch_items; ++i) {
      auto data_start = &input_data[splits_vec(i)];
      int output_start_idx = ngrams_splits_data[i];
      for (int ngram_width : ngram_widths_) {
        auto output_start = &ngrams_data[output_start_idx];
        int length = splits_vec(i + 1) - splits_vec(i);
        int num_ngrams = get_num_ngrams(length, ngram_width);
        CreateNgrams(data_start, output_start, num_ngrams, ngram_width);
        output_start_idx += num_ngrams;
      }
      // If we're preserving short sequences, check whether no ngrams were
      // generated by comparing the current output start idx to the original
      // one (ngrams_splits_data). If no ngrams were generated, the two will
      // be equal (since we increment output_start_idx by num_ngrams every
      // time we create a set of ngrams).
      if (preserve_short_ && output_start_idx == ngrams_splits_data[i]) {
        int data_length = splits_vec(i + 1) - splits_vec(i);
        // One legitimate reason to not have any ngrams when preserve_short_
        // is true is if the sequence itself is empty. In that case, move on.
        if (data_length == 0) {
          continue;
        }
        // We don't have to worry about dynamic padding sizes here: if padding
        // was dynamic, every sequence would have had sufficient padding to
        // generate at least one ngram.
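        // (Concretely: when pad_width_ < 0, get_pad_width(w) returns w - 1,
        // so get_num_ngrams(length, w) == length + w - 1 >= 1 for any
        // length >= 1, and this branch is unreachable. The width computed
        // below therefore only ever uses a non-negative pad_width_.)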
        int ngram_width = data_length + 2 * pad_width_;
        auto output_start = &ngrams_data[output_start_idx];
        int num_ngrams = 1;
        CreateNgrams(data_start, output_start, num_ngrams, ngram_width);
      }
    }
  }

  void CreateNgrams(const tstring* data, tstring* output, int num_ngrams,
                    int ngram_width) const {
    for (int ngram_index = 0; ngram_index < num_ngrams; ++ngram_index) {
      int pad_width = get_pad_width(ngram_width);
      int left_padding = std::max(0, pad_width - ngram_index);
      int right_padding =
          std::max(0, pad_width - (num_ngrams - (ngram_index + 1)));
      int num_tokens = ngram_width - (left_padding + right_padding);
      int data_start_index = left_padding > 0 ? 0 : ngram_index - pad_width;

      // Calculate the total expected size of the ngram so we can reserve the
      // correct amount of space in the string.
      int ngram_size = 0;
      // Size of the left padding.
      ngram_size += left_padding * left_pad_.length();
      // Size of the tokens.
      for (int n = 0; n < num_tokens; ++n) {
        ngram_size += data[data_start_index + n].length();
      }
      // Size of the right padding.
      ngram_size += right_padding * right_pad_.length();
      // Size of the separators.
      int num_separators = left_padding + right_padding + num_tokens - 1;
      ngram_size += num_separators * separator_.length();

      // Build the ngram.
      tstring* ngram = &output[ngram_index];
      ngram->reserve(ngram_size);
      for (int n = 0; n < left_padding; ++n) {
        ngram->append(left_pad_);
        ngram->append(separator_);
      }
      // Only output the first num_tokens - 1 pairs of data and separator.
      for (int n = 0; n < num_tokens - 1; ++n) {
        ngram->append(data[data_start_index + n]);
        ngram->append(separator_);
      }
      // Handle the cases where there are no tokens or no right padding, as
      // these can result in consecutive separators.
      if (num_tokens > 0) {
        // If we have tokens, output the last one, then pair each separator
        // with the right padding that follows, to ensure the ngram ends with
        // either a token or the right pad.
        ngram->append(data[data_start_index + num_tokens - 1]);
        for (int n = 0; n < right_padding; ++n) {
          ngram->append(separator_);
          ngram->append(right_pad_);
        }
      } else {
        // If we don't have tokens, the last item appended to the ngram was
        // the separator from the left-padding loop above. Hence, output right
        // pads and separators, making sure to finish with a pad, not a
        // separator.
        for (int n = 0; n < right_padding - 1; ++n) {
          ngram->append(right_pad_);
          ngram->append(separator_);
        }
        ngram->append(right_pad_);
      }

      // In debug mode only: validate that we've reserved enough space for
      // the ngram.
      DCHECK_EQ(ngram_size, ngram->size());
    }
  }

  string separator_;
  string left_pad_;
  string right_pad_;
  bool use_pad_;
  bool extend_pad_;
  bool preserve_short_;

  std::vector<int> ngram_widths_;
  int pad_width_;
};

}  // namespace

REGISTER_KERNEL_BUILDER(Name("StringNGrams")
                            .Device(tensorflow::DEVICE_CPU)
                            .TypeConstraint<int32>("Tsplits"),
                        StringNGramsOp<int32>);
REGISTER_KERNEL_BUILDER(Name("StringNGrams")
                            .Device(tensorflow::DEVICE_CPU)
                            .TypeConstraint<int64>("Tsplits"),
                        StringNGramsOp<int64>);

}  // namespace text
}  // namespace tensorflow
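// Worked example (an illustrative trace of CreateNgrams above; the attribute
// values are assumptions chosen for the example, not documented behavior):
//   data = ["a", "b", "c"], data_splits = [0, 3], ngram_widths = [2],
//   separator = " ", left_pad = "LP", right_pad = "RP", pad_width = 1.
// Then get_num_ngrams(3, 2) = (3 + 2 * 1) - 2 + 1 = 4, and the kernel emits
//   ngrams = ["LP a", "a b", "b c", "c RP"], ngrams_splits = [0, 4].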