• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <stdint.h>
17 #include <cstddef>
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <vector>
22 
23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24 #include "unicode/appendable.h"  // TF:icu
25 #include "unicode/schriter.h"  // TF:icu
26 #include "unicode/uchar.h"  // TF:icu
27 #include "unicode/ucnv.h"  // TF:icu
28 #include "unicode/ucnv_err.h"  // TF:icu
29 #include "unicode/umachine.h"  // TF:icu
30 #include "unicode/uniset.h"  // TF:icu
31 #include "unicode/unistr.h"  // TF:icu
32 #include "unicode/uset.h"  // TF:icu
33 #include "unicode/utypes.h"  // TF:icu
34 #include "tensorflow/core/framework/bounds_check.h"
35 #include "tensorflow/core/framework/kernel_def_builder.h"
36 #include "tensorflow/core/framework/op.h"
37 #include "tensorflow/core/framework/op_kernel.h"
38 #include "tensorflow/core/framework/register_types.h"
39 #include "tensorflow/core/framework/tensor.h"
40 #include "tensorflow/core/framework/tensor_shape.h"
41 #include "tensorflow/core/framework/tensor_types.h"
42 #include "tensorflow/core/framework/types.h"
43 #include "tensorflow/core/kernels/string_util.h"
44 #include "tensorflow/core/lib/core/errors.h"
45 #include "tensorflow/core/lib/core/status.h"
46 #include "tensorflow/core/lib/core/stringpiece.h"
47 #include "tensorflow/core/platform/types.h"
48 #include "tensorflow/core/util/bcast.h"
49 #include "tensorflow/core/util/ptr_util.h"
50 
51 namespace tensorflow {
52 namespace {
53 
Encode(const UnicodeEncoding encoding,const icu::UnicodeString & in,string * out)54 void Encode(const UnicodeEncoding encoding, const icu::UnicodeString& in,
55             string* out) {
56   if (encoding == UnicodeEncoding::UTF8) {
57     out->clear();
58     in.toUTF8String(*out);
59   } else if (encoding == UnicodeEncoding::UTF16BE) {
60     // TODO(gbillock): consider using the
61     // extract(char *dest, int32_t destCapacity, UConverter *cnv)
62     // for UTF16/32
63     out->clear();  // subtle: must come before reserve()
64     out->reserve(2 * in.length() + 1);
65     const char16_t* buf = in.getBuffer();
66     for (int i = 0; i < in.length(); ++i) {
67       // Emit big-endian encoding for UTF-16 always.
68       out->push_back((buf[i] & 0xFF00) >> 8);
69       out->push_back(buf[i] & 0x00FF);
70     }
71   } else if (encoding == UnicodeEncoding::UTF32BE) {
72     out->clear();  // subtle: must come before reserve()
73     out->reserve(4 * in.countChar32() + 1);
74     icu::StringCharacterIterator it(in);
75     UChar32 ch;
76     while (it.hasNext()) {
77       ch = it.next32PostInc();
78       out->push_back((ch & 0xFF000000) >> 24);
79       out->push_back((ch & 0x00FF0000) >> 16);
80       out->push_back((ch & 0x0000FF00) >> 8);
81       out->push_back((ch & 0x000000FF));
82     }
83   }
84 }
85 
86 // This error callback is only useful for finding illegal encoding errors when
87 // we want to be strict -- otherwise illegal encodings are replaced on read
88 // with 0xFFFD and signaled to the callback.
unicode_error_callback(const void * context,UConverterToUnicodeArgs * args,const char * codeUnits,int32_t length,UConverterCallbackReason reason,UErrorCode * pErrorCode)89 void unicode_error_callback(const void* context, UConverterToUnicodeArgs* args,
90                             const char* codeUnits, int32_t length,
91                             UConverterCallbackReason reason,
92                             UErrorCode* pErrorCode) {
93   // Careful: this depends on setting up the context settings when the
94   // callback is registered.
95   bool* format_error = const_cast<bool*>(static_cast<const bool*>(context));
96 
97   if (reason == UCNV_UNASSIGNED || reason == UCNV_ILLEGAL ||
98       reason == UCNV_IRREGULAR) {
99     *format_error = true;
100   }
101 
102   // Side note: the default behavior in this case is that without a substitution
103   // made by the callback, the UConverter will signal an error to the iterator
104   // making the string iteration bail out. Instead, forward to the built-in
105   // substitution handler.
106   UCNV_TO_U_CALLBACK_SUBSTITUTE(nullptr, args, codeUnits, length, reason,
107                                 pErrorCode);
108 }
109 
110 // Iterates through a source string given the provided input UConverter specific
111 // to the encoding for that string. Calls a provided callback for each codepoint
112 // consumed. Provides the callback with the codepoint and the number of bytes
113 // consumed from the input string to produce it. If there are invalid encoding
114 // loci in the source string, they will be provided as a 0xFFFD codepoint to
115 // the callback, unless the "fail_on_formatting_error" arg is set, in which
116 // case the callback will be passed the signal that there is such an invalid
117 // encoding position.
118 // callback: function(UChar32 codepoint, int num_bytes_consumed_from_source_str,
119 //                    bool fatal_format_error)
IterateUnicodeString(const string & str,UConverter * converter,std::function<void (UChar32,int,bool)> callback)120 void IterateUnicodeString(const string& str, UConverter* converter,
121                           std::function<void(UChar32, int, bool)> callback) {
122   const char* source = str.data();
123   const char* limit = str.data() + str.length();
124   UErrorCode status = U_ZERO_ERROR;
125 
126   UConverterToUCallback oldAction = nullptr;
127   const void* oldContext = nullptr;
128   bool format_error = false;
129 
130   // Subtle. You can't make a function pointer from a std::function. :-(
131   // Instead, we pass the boolean pointer as the "context" object.
132   ucnv_setToUCallBack(converter, unicode_error_callback, &format_error,
133                       &oldAction, &oldContext, &status);
134   if (U_FAILURE(status)) {
135     LOG(ERROR) << "Could not set unicode error callback on converter";
136     return;
137   }
138 
139   while (source < limit) {
140     const char* source_pre_fetch = source;
141     // Note: ucnv_getNextUChar returns 0xFFFD on an encoding error.
142     UChar32 next_char = ucnv_getNextUChar(converter, &source, limit, &status);
143     if (U_FAILURE(status)) {
144       source = limit;
145     }
146     int bytes_consumed = source - source_pre_fetch;
147     callback(next_char, bytes_consumed, format_error);
148     format_error = false;
149   }
150 
151   ucnv_setToUCallBack(converter, oldAction, oldContext, nullptr, nullptr,
152                       &status);
153 }
154 
155 // Lifecycle wrapper for UConverter making it easier to use with thread_local.
156 // TODO(gbillock): Consider whether to use the higher-level convert API and
157 // create a specialized fast code path for UTF8.
158 class WrappedConverter {
159  public:
WrappedConverter()160   WrappedConverter() {}
161 
~WrappedConverter()162   ~WrappedConverter() {
163     if (converter_) {
164       ucnv_close(converter_);
165     }
166   }
167 
init(const string & name)168   void init(const string& name) {
169     if (converter_ && name == name_) {
170       // Note: this reset is not typically needed, but if not done, then in some
171       // cases the cached converter will maintain state of input endianness
172       // which isn't valid from input to input in every batched case.
173       ucnv_reset(converter_);
174       return;
175     }
176 
177     if (converter_) {
178       ucnv_close(converter_);
179       converter_ = nullptr;
180       name_ = "";
181     }
182 
183     UErrorCode status = U_ZERO_ERROR;
184     converter_ = ucnv_open(name.c_str(), &status);
185     if (U_FAILURE(status)) {
186       if (converter_) {
187         ucnv_close(converter_);
188         converter_ = nullptr;
189       }
190     } else {
191       name_ = name;
192     }
193   }
194 
195   UConverter* converter_ = nullptr;
196   string name_;
197 };
198 
199 struct ErrorOptions {
200   UChar32 subst = 0xFFFD;
201   bool elide_replacement = false;
202   bool replace_control_chars = false;
203   bool error_on_malformatting = false;
204 };
205 
GetErrorOptions(OpKernelConstruction * ctx,ErrorOptions * out)206 Status GetErrorOptions(OpKernelConstruction* ctx, ErrorOptions* out) {
207   *out = ErrorOptions();
208 
209   string error_policy;
210   TF_RETURN_IF_ERROR(ctx->GetAttr("errors", &error_policy));
211 
212   if (error_policy == "replace") {
213     out->elide_replacement = false;
214   } else if (error_policy == "ignore") {
215     out->elide_replacement = true;
216   } else if (error_policy == "strict") {
217     out->error_on_malformatting = true;
218   } else {
219     return errors::InvalidArgument(
220         "errors policy must be one of 'strict', 'replace', or 'ignore'");
221   }
222 
223   int32 replacement_char;
224   TF_RETURN_IF_ERROR(ctx->GetAttr("replacement_char", &replacement_char));
225 
226   if (replacement_char >= UCHAR_MIN_VALUE &&
227       replacement_char <= UCHAR_MAX_VALUE) {
228     out->subst = replacement_char;
229   } else {
230     return errors::InvalidArgument(
231         "replacement_char out of unicode codepoint range");
232   }
233 
234   if (ctx->HasAttr("replace_control_characters")) {
235     TF_RETURN_IF_ERROR(ctx->GetAttr("replace_control_characters",
236                                     &(out->replace_control_chars)));
237   }
238 
239   return Status::OK();
240 }
241 
ShouldHandleFormatError(const ErrorOptions & error_options,UChar32 ch,bool format_error)242 inline bool ShouldHandleFormatError(const ErrorOptions& error_options,
243                                     UChar32 ch, bool format_error) {
244   return ((error_options.replace_control_chars && ch <= 0x1F) || format_error);
245 }
246 
247 }  // namespace
248 
249 class UnicodeTranscodeOp : public OpKernel {
250  public:
UnicodeTranscodeOp(OpKernelConstruction * ctx)251   explicit UnicodeTranscodeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
252     OP_REQUIRES_OK(ctx, GetErrorOptions(ctx, &error_options_));
253 
254     string output_encoding;
255     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_encoding", &output_encoding));
256     OP_REQUIRES_OK(ctx,
257                    ParseUnicodeEncoding(output_encoding, &output_encoding_));
258 
259     OP_REQUIRES_OK(ctx, ctx->GetAttr("input_encoding", &input_encoding_));
260     // Make a temporary UConverter to ensure it will create without error
261     // at execution time (and to warm any data caches the converter needs).
262     // This instance is not used.
263     std::unique_ptr<WrappedConverter> input_encoder =
264         absl::make_unique<WrappedConverter>();
265     input_encoder->init(input_encoding_);
266     OP_REQUIRES(ctx, input_encoder->converter_,
267                 errors::InvalidArgument(
268                     "Could not create converter for input encoding: " +
269                     input_encoding_));
270   }
271 
Compute(OpKernelContext * ctx)272   void Compute(OpKernelContext* ctx) override {
273     const Tensor* input_tensor;
274     OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
275 
276     static thread_local std::unique_ptr<WrappedConverter> input_encoder;
277     if (!input_encoder) {
278       input_encoder.reset(new WrappedConverter());
279     }
280     input_encoder->init(input_encoding_);
281     OP_REQUIRES(ctx, input_encoder->converter_,
282                 errors::InvalidArgument(
283                     "Could not create converter for input encoding: " +
284                     input_encoding_));
285 
286     // Output may be forwardable from input, in which case work in-place.
287     Tensor* output_tensor;
288     std::unique_ptr<Tensor> maybe_forwarded =
289         ctx->forward_input(0 /*input_index*/, 0 /*output_index*/,
290                            tensorflow::DT_STRING, input_tensor->shape(),
291                            ctx->input_memory_type(0), ctx->input_alloc_attr(0));
292     if (maybe_forwarded) {
293       output_tensor = maybe_forwarded.get();
294       OP_REQUIRES_OK(ctx, ctx->set_output("output", *output_tensor));
295     } else {
296       OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
297                                                &output_tensor));
298       output_tensor->flat<string>() = input_tensor->flat<string>();
299     }
300 
301     auto output_flat = output_tensor->flat<string>();
302     bool found_any_format_error = false;
303     for (size_t i = 0; i < output_flat.size(); ++i) {
304       Transcode(&(output_flat(i)), input_encoder->converter_,
305                 &found_any_format_error);
306     }
307     if (error_options_.error_on_malformatting && found_any_format_error) {
308       ctx->CtxFailure(
309           errors::InvalidArgument("Invalid formatting on input string"));
310     }
311   }
312 
313  private:
314   // Consume a codepoint from the input string and add it to the buffer.
315   // This function takes care of any replacement configuration on invalid or
316   // out-of-range inputs.
TranslateCodepoints(icu::UnicodeString * s,bool * found_any_format_error,UChar32 ch,int src_bytes,bool format_error)317   void TranslateCodepoints(icu::UnicodeString* s, bool* found_any_format_error,
318                            UChar32 ch, int src_bytes, bool format_error) {
319     if (ShouldHandleFormatError(error_options_, ch, format_error)) {
320       *found_any_format_error = true;
321       if (error_options_.elide_replacement) {
322         return;
323       } else {
324         ch = error_options_.subst;
325       }
326     }
327     s->append(ch);
328   }
329 
330   // Transcode the string from input encoding to the output_encoding_. If
331   // non-valid characters are encountered, use the subst_/elide_replacement_
332   // config to handle them.
Transcode(string * s,UConverter * input_encoder,bool * found_any_format_error)333   void Transcode(string* s, UConverter* input_encoder,
334                  bool* found_any_format_error) {
335     icu::UnicodeString source;
336     IterateUnicodeString(
337         *s, input_encoder,
338         std::bind(&UnicodeTranscodeOp::TranslateCodepoints, this, &source,
339                   found_any_format_error, std::placeholders::_1,
340                   std::placeholders::_2, std::placeholders::_3));
341 
342     Encode(output_encoding_, source, s);
343   }
344 
345   string input_encoding_;
346   ErrorOptions error_options_;
347   UnicodeEncoding output_encoding_ = UnicodeEncoding::UTF8;
348 };
349 
350 REGISTER_KERNEL_BUILDER(Name("UnicodeTranscode").Device(DEVICE_CPU),
351                         UnicodeTranscodeOp);
352 
353 class UnicodeDecodeBaseOp : public OpKernel {
354  public:
UnicodeDecodeBaseOp(OpKernelConstruction * ctx,bool generate_offsets)355   explicit UnicodeDecodeBaseOp(OpKernelConstruction* ctx, bool generate_offsets)
356       : OpKernel(ctx), generate_offsets_(generate_offsets) {
357     OP_REQUIRES_OK(ctx, GetErrorOptions(ctx, &error_options_));
358     OP_REQUIRES_OK(ctx, ctx->GetAttr("input_encoding", &input_encoding_));
359     // Make a temporary UConverter to ensure it will create without error
360     // at execution time (and to warm any data caches the converter needs).
361     // This instance is not used.
362     std::unique_ptr<WrappedConverter> input_encoder =
363         absl::make_unique<WrappedConverter>();
364     input_encoder->init(input_encoding_);
365     OP_REQUIRES(ctx, input_encoder->converter_,
366                 errors::InvalidArgument(
367                     "Could not create converter for input encoding: " +
368                     input_encoding_));
369   }
370 
Decode(OpKernelContext * ctx,std::vector<UChar32> * char_values,std::vector<int64> * offset_values,int * current_offset,int64 * next_row_split,UChar32 char_value,int char_length,bool found_any_format_error)371   void Decode(OpKernelContext* ctx, std::vector<UChar32>* char_values,
372               std::vector<int64>* offset_values, int* current_offset,
373               int64* next_row_split, UChar32 char_value, int char_length,
374               bool found_any_format_error) {
375     if (error_options_.error_on_malformatting && found_any_format_error) {
376       ctx->CtxFailure(
377           errors::InvalidArgument("Invalid formatting on input string"));
378     }
379     UChar32 decoded_value = char_value;
380     if (ShouldHandleFormatError(error_options_, char_value,
381                                 found_any_format_error)) {
382       if (error_options_.elide_replacement && (offset_values != nullptr)) {
383         *current_offset += char_length;
384         return;
385       } else {
386         decoded_value = error_options_.subst;
387       }
388     }
389 
390     // Emit the char value.
391     char_values->push_back(decoded_value);
392 
393     // Emit the byte offset
394     if (offset_values != nullptr) {
395       offset_values->push_back(*current_offset);
396       *current_offset += char_length;
397     }
398     *next_row_split += 1;
399   }
400 
Compute(OpKernelContext * ctx)401   void Compute(OpKernelContext* ctx) override {
402     const Tensor* input_tensor;
403     OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
404 
405     // Go through all the strings in `input`.
406     const auto& input_vec = input_tensor->flat<string>();
407 
408     std::unique_ptr<WrappedConverter> input_encoder =
409         absl::make_unique<WrappedConverter>();
410     input_encoder->init(input_encoding_);
411     OP_REQUIRES(ctx, input_encoder->converter_,
412                 errors::InvalidArgument(
413                     "Could not create converter for input encoding: " +
414                     input_encoding_));
415 
416     std::vector<UChar32> char_values;
417     std::vector<int64> offset_values;
418 
419     Tensor* output_row_splits;
420     OP_REQUIRES_OK(ctx, ctx->allocate_output("row_splits",
421                                              {input_tensor->NumElements() + 1},
422                                              &output_row_splits));
423     auto out_row_splits = output_row_splits->vec<int64>();
424 
425     int row_split_index = 0;
426     int64 next_row_split = 0;
427     for (int i = 0; i < input_vec.size(); ++i) {
428       const string& input = input_vec(i);
429       // Convert input strings into unicode values. Output to a list of
430       // char_values, record row splits and char_to_byte_starts, which are all
431       // the fields needed to construct a RaggedTensor.
432       out_row_splits(row_split_index) = next_row_split;
433       row_split_index++;
434       int current_offset = 0;
435       IterateUnicodeString(
436           input, input_encoder->converter_,
437           std::bind(&UnicodeDecodeBaseOp::Decode, this, ctx, &char_values,
438                     &offset_values, &current_offset, &next_row_split,
439                     std::placeholders::_1, std::placeholders::_2,
440                     std::placeholders::_3));
441     }
442     out_row_splits(row_split_index) = next_row_split;
443 
444     Tensor* output_char_values;
445     OP_REQUIRES_OK(
446         ctx, ctx->allocate_output("char_values",
447                                   {static_cast<int64>(char_values.size())},
448                                   &output_char_values));
449     auto out_char_values = output_char_values->vec<int32>();
450     if (generate_offsets_) {
451       DCHECK(offset_values.size() == char_values.size());
452       Tensor* output_offset_values;
453       OP_REQUIRES_OK(
454           ctx, ctx->allocate_output("char_to_byte_starts",
455                                     {static_cast<int64>(offset_values.size())},
456                                     &output_offset_values));
457       auto out_offset_values = output_offset_values->vec<int64>();
458 
459       // Load output tensors from intermediate value arrays.
460       for (int i = 0; i < char_values.size(); ++i) {
461         out_char_values(i) = static_cast<int32>(char_values[i]);
462         out_offset_values(i) = offset_values[i];
463       }
464     } else {
465       for (int i = 0; i < char_values.size(); ++i) {
466         out_char_values(i) = static_cast<int32>(char_values[i]);
467       }
468     }
469   }
470 
471  private:
472   string input_encoding_;
473   ErrorOptions error_options_;
474   bool generate_offsets_ = false;
475 };
476 
477 class UnicodeDecodeOp : public UnicodeDecodeBaseOp {
478  public:
UnicodeDecodeOp(OpKernelConstruction * ctx)479   explicit UnicodeDecodeOp(OpKernelConstruction* ctx)
480       : UnicodeDecodeBaseOp(ctx, false) {}
481 };
482 
483 class UnicodeDecodeWithOffsetsOp : public UnicodeDecodeBaseOp {
484  public:
UnicodeDecodeWithOffsetsOp(OpKernelConstruction * ctx)485   explicit UnicodeDecodeWithOffsetsOp(OpKernelConstruction* ctx)
486       : UnicodeDecodeBaseOp(ctx, true) {}
487 };
488 
489 REGISTER_KERNEL_BUILDER(Name("UnicodeDecode").Device(DEVICE_CPU),
490                         UnicodeDecodeOp);
491 REGISTER_KERNEL_BUILDER(Name("UnicodeDecodeWithOffsets").Device(DEVICE_CPU),
492                         UnicodeDecodeWithOffsetsOp);
493 
494 class UnicodeEncodeOp : public OpKernel {
495  public:
UnicodeEncodeOp(OpKernelConstruction * ctx)496   explicit UnicodeEncodeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
497     string encoding_tmp;
498     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_encoding", &encoding_tmp));
499     OP_REQUIRES_OK(ctx, ParseUnicodeEncoding(encoding_tmp, &encoding_));
500     OP_REQUIRES_OK(ctx, GetErrorOptions(ctx, &error_options_));
501   }
502 
503   /**
504    * Encodes Unicode codepoints into the desired string representation.
505    *
506    * We lose a dimension while encoding, since a series of integer codepoints is
507    * encoded into a single string.
508    *
509    * This accepts two input tensors: a rank 1 tensor of code point values and
510    * a single rank 1 tensor of splits which determine where each string begins
511    * and ends from the provided code points.
512    */
Compute(OpKernelContext * context)513   void Compute(OpKernelContext* context) override {
514     // Get inputs
515     const Tensor& input_tensor = context->input(0);
516     const auto input_tensor_flat = input_tensor.flat<int32>();
517     const Tensor& input_splits = context->input(1);
518     const auto input_splits_flat = input_splits.flat<int64>();
519 
520     // Since we limit to a 2-D input (flat_values of rank 1 and a single splits
521     // tensor), our output dimension will be 1 with it's size equal to the
522     // number of splits (outer dimension or ragged tensor).
523     TensorShape output_shape({input_splits.dim_size(0) - 1});
524     Tensor* output_tensor;
525     OP_REQUIRES_OK(context, context->allocate_output("output", output_shape,
526                                                      &output_tensor));
527     auto output_tensor_flat = output_tensor->flat<string>();
528 
529     // Use a single index over the flattened input values tensor.
530     int idx = 0;
531     // Loop through our split dimension to create a new string at each split.
532     for (int i = 1; i < input_splits_flat.size(); ++i) {
533       icu::UnicodeString unicode_string;
534       icu::UnicodeStringAppendable appendable_unicode_string(unicode_string);
535       for (; idx < input_splits_flat(i); ++idx) {
536         int32 code_point = input_tensor_flat(idx);
537         // Check for invalid code point
538         if (code_point > UCHAR_MAX_VALUE || code_point < UCHAR_MIN_VALUE) {
539           if (error_options_.error_on_malformatting) {
540             context->CtxFailure(errors::InvalidArgument(
541                 "Code point value out of valid Unicode range."));
542             return;
543           } else if (!error_options_.elide_replacement) {
544             code_point = error_options_.subst;
545           }
546         }
547         appendable_unicode_string.appendCodePoint(code_point);
548       }
549       // Encode our string and save in the output.
550       string result;
551       Encode(encoding_, unicode_string, &result);
552       output_tensor_flat(i - 1) = result;
553     }
554   }
555 
556  private:
557   UnicodeEncoding encoding_;
558   ErrorOptions error_options_;
559 };
560 
561 REGISTER_KERNEL_BUILDER(Name("UnicodeEncode").Device(DEVICE_CPU),
562                         UnicodeEncodeOp);
563 
564 }  // namespace tensorflow
565