1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <stdint.h>
17 #include <cstddef>
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <vector>
22
23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24 #include "unicode/appendable.h" // TF:icu
25 #include "unicode/schriter.h" // TF:icu
26 #include "unicode/uchar.h" // TF:icu
27 #include "unicode/ucnv.h" // TF:icu
28 #include "unicode/ucnv_err.h" // TF:icu
29 #include "unicode/umachine.h" // TF:icu
30 #include "unicode/uniset.h" // TF:icu
31 #include "unicode/unistr.h" // TF:icu
32 #include "unicode/uset.h" // TF:icu
33 #include "unicode/utypes.h" // TF:icu
34 #include "tensorflow/core/framework/bounds_check.h"
35 #include "tensorflow/core/framework/kernel_def_builder.h"
36 #include "tensorflow/core/framework/op.h"
37 #include "tensorflow/core/framework/op_kernel.h"
38 #include "tensorflow/core/framework/register_types.h"
39 #include "tensorflow/core/framework/tensor.h"
40 #include "tensorflow/core/framework/tensor_shape.h"
41 #include "tensorflow/core/framework/tensor_types.h"
42 #include "tensorflow/core/framework/types.h"
43 #include "tensorflow/core/kernels/string_util.h"
44 #include "tensorflow/core/lib/core/errors.h"
45 #include "tensorflow/core/lib/core/status.h"
46 #include "tensorflow/core/lib/core/stringpiece.h"
47 #include "tensorflow/core/platform/types.h"
48 #include "tensorflow/core/util/bcast.h"
49 #include "tensorflow/core/util/ptr_util.h"
50
51 namespace tensorflow {
52 namespace {
53
Encode(const UnicodeEncoding encoding,const icu::UnicodeString & in,string * out)54 void Encode(const UnicodeEncoding encoding, const icu::UnicodeString& in,
55 string* out) {
56 if (encoding == UnicodeEncoding::UTF8) {
57 out->clear();
58 in.toUTF8String(*out);
59 } else if (encoding == UnicodeEncoding::UTF16BE) {
60 // TODO(gbillock): consider using the
61 // extract(char *dest, int32_t destCapacity, UConverter *cnv)
62 // for UTF16/32
63 out->clear(); // subtle: must come before reserve()
64 out->reserve(2 * in.length() + 1);
65 const char16_t* buf = in.getBuffer();
66 for (int i = 0; i < in.length(); ++i) {
67 // Emit big-endian encoding for UTF-16 always.
68 out->push_back((buf[i] & 0xFF00) >> 8);
69 out->push_back(buf[i] & 0x00FF);
70 }
71 } else if (encoding == UnicodeEncoding::UTF32BE) {
72 out->clear(); // subtle: must come before reserve()
73 out->reserve(4 * in.countChar32() + 1);
74 icu::StringCharacterIterator it(in);
75 UChar32 ch;
76 while (it.hasNext()) {
77 ch = it.next32PostInc();
78 out->push_back((ch & 0xFF000000) >> 24);
79 out->push_back((ch & 0x00FF0000) >> 16);
80 out->push_back((ch & 0x0000FF00) >> 8);
81 out->push_back((ch & 0x000000FF));
82 }
83 }
84 }
85
86 // This error callback is only useful for finding illegal encoding errors when
87 // we want to be strict -- otherwise illegal encodings are replaced on read
88 // with 0xFFFD and signaled to the callback.
unicode_error_callback(const void * context,UConverterToUnicodeArgs * args,const char * codeUnits,int32_t length,UConverterCallbackReason reason,UErrorCode * pErrorCode)89 void unicode_error_callback(const void* context, UConverterToUnicodeArgs* args,
90 const char* codeUnits, int32_t length,
91 UConverterCallbackReason reason,
92 UErrorCode* pErrorCode) {
93 // Careful: this depends on setting up the context settings when the
94 // callback is registered.
95 bool* format_error = const_cast<bool*>(static_cast<const bool*>(context));
96
97 if (reason == UCNV_UNASSIGNED || reason == UCNV_ILLEGAL ||
98 reason == UCNV_IRREGULAR) {
99 *format_error = true;
100 }
101
102 // Side note: the default behavior in this case is that without a substitution
103 // made by the callback, the UConverter will signal an error to the iterator
104 // making the string iteration bail out. Instead, forward to the built-in
105 // substitution handler.
106 UCNV_TO_U_CALLBACK_SUBSTITUTE(nullptr, args, codeUnits, length, reason,
107 pErrorCode);
108 }
109
110 // Iterates through a source string given the provided input UConverter specific
111 // to the encoding for that string. Calls a provided callback for each codepoint
112 // consumed. Provides the callback with the codepoint and the number of bytes
113 // consumed from the input string to produce it. If there are invalid encoding
114 // loci in the source string, they will be provided as a 0xFFFD codepoint to
115 // the callback, unless the "fail_on_formatting_error" arg is set, in which
116 // case the callback will be passed the signal that there is such an invalid
117 // encoding position.
118 // callback: function(UChar32 codepoint, int num_bytes_consumed_from_source_str,
119 // bool fatal_format_error)
IterateUnicodeString(const string & str,UConverter * converter,std::function<void (UChar32,int,bool)> callback)120 void IterateUnicodeString(const string& str, UConverter* converter,
121 std::function<void(UChar32, int, bool)> callback) {
122 const char* source = str.data();
123 const char* limit = str.data() + str.length();
124 UErrorCode status = U_ZERO_ERROR;
125
126 UConverterToUCallback oldAction = nullptr;
127 const void* oldContext = nullptr;
128 bool format_error = false;
129
130 // Subtle. You can't make a function pointer from a std::function. :-(
131 // Instead, we pass the boolean pointer as the "context" object.
132 ucnv_setToUCallBack(converter, unicode_error_callback, &format_error,
133 &oldAction, &oldContext, &status);
134 if (U_FAILURE(status)) {
135 LOG(ERROR) << "Could not set unicode error callback on converter";
136 return;
137 }
138
139 while (source < limit) {
140 const char* source_pre_fetch = source;
141 // Note: ucnv_getNextUChar returns 0xFFFD on an encoding error.
142 UChar32 next_char = ucnv_getNextUChar(converter, &source, limit, &status);
143 if (U_FAILURE(status)) {
144 source = limit;
145 }
146 int bytes_consumed = source - source_pre_fetch;
147 callback(next_char, bytes_consumed, format_error);
148 format_error = false;
149 }
150
151 ucnv_setToUCallBack(converter, oldAction, oldContext, nullptr, nullptr,
152 &status);
153 }
154
155 // Lifecycle wrapper for UConverter making it easier to use with thread_local.
156 // TODO(gbillock): Consider whether to use the higher-level convert API and
157 // create a specialized fast code path for UTF8.
158 class WrappedConverter {
159 public:
WrappedConverter()160 WrappedConverter() {}
161
~WrappedConverter()162 ~WrappedConverter() {
163 if (converter_) {
164 ucnv_close(converter_);
165 }
166 }
167
init(const string & name)168 void init(const string& name) {
169 if (converter_ && name == name_) {
170 // Note: this reset is not typically needed, but if not done, then in some
171 // cases the cached converter will maintain state of input endianness
172 // which isn't valid from input to input in every batched case.
173 ucnv_reset(converter_);
174 return;
175 }
176
177 if (converter_) {
178 ucnv_close(converter_);
179 converter_ = nullptr;
180 name_ = "";
181 }
182
183 UErrorCode status = U_ZERO_ERROR;
184 converter_ = ucnv_open(name.c_str(), &status);
185 if (U_FAILURE(status)) {
186 if (converter_) {
187 ucnv_close(converter_);
188 converter_ = nullptr;
189 }
190 } else {
191 name_ = name;
192 }
193 }
194
195 UConverter* converter_ = nullptr;
196 string name_;
197 };
198
199 struct ErrorOptions {
200 UChar32 subst = 0xFFFD;
201 bool elide_replacement = false;
202 bool replace_control_chars = false;
203 bool error_on_malformatting = false;
204 };
205
GetErrorOptions(OpKernelConstruction * ctx,ErrorOptions * out)206 Status GetErrorOptions(OpKernelConstruction* ctx, ErrorOptions* out) {
207 *out = ErrorOptions();
208
209 string error_policy;
210 TF_RETURN_IF_ERROR(ctx->GetAttr("errors", &error_policy));
211
212 if (error_policy == "replace") {
213 out->elide_replacement = false;
214 } else if (error_policy == "ignore") {
215 out->elide_replacement = true;
216 } else if (error_policy == "strict") {
217 out->error_on_malformatting = true;
218 } else {
219 return errors::InvalidArgument(
220 "errors policy must be one of 'strict', 'replace', or 'ignore'");
221 }
222
223 int32 replacement_char;
224 TF_RETURN_IF_ERROR(ctx->GetAttr("replacement_char", &replacement_char));
225
226 if (replacement_char >= UCHAR_MIN_VALUE &&
227 replacement_char <= UCHAR_MAX_VALUE) {
228 out->subst = replacement_char;
229 } else {
230 return errors::InvalidArgument(
231 "replacement_char out of unicode codepoint range");
232 }
233
234 if (ctx->HasAttr("replace_control_characters")) {
235 TF_RETURN_IF_ERROR(ctx->GetAttr("replace_control_characters",
236 &(out->replace_control_chars)));
237 }
238
239 return Status::OK();
240 }
241
ShouldHandleFormatError(const ErrorOptions & error_options,UChar32 ch,bool format_error)242 inline bool ShouldHandleFormatError(const ErrorOptions& error_options,
243 UChar32 ch, bool format_error) {
244 return ((error_options.replace_control_chars && ch <= 0x1F) || format_error);
245 }
246
247 } // namespace
248
249 class UnicodeTranscodeOp : public OpKernel {
250 public:
UnicodeTranscodeOp(OpKernelConstruction * ctx)251 explicit UnicodeTranscodeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
252 OP_REQUIRES_OK(ctx, GetErrorOptions(ctx, &error_options_));
253
254 string output_encoding;
255 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_encoding", &output_encoding));
256 OP_REQUIRES_OK(ctx,
257 ParseUnicodeEncoding(output_encoding, &output_encoding_));
258
259 OP_REQUIRES_OK(ctx, ctx->GetAttr("input_encoding", &input_encoding_));
260 // Make a temporary UConverter to ensure it will create without error
261 // at execution time (and to warm any data caches the converter needs).
262 // This instance is not used.
263 std::unique_ptr<WrappedConverter> input_encoder =
264 absl::make_unique<WrappedConverter>();
265 input_encoder->init(input_encoding_);
266 OP_REQUIRES(ctx, input_encoder->converter_,
267 errors::InvalidArgument(
268 "Could not create converter for input encoding: " +
269 input_encoding_));
270 }
271
Compute(OpKernelContext * ctx)272 void Compute(OpKernelContext* ctx) override {
273 const Tensor* input_tensor;
274 OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
275
276 static thread_local std::unique_ptr<WrappedConverter> input_encoder;
277 if (!input_encoder) {
278 input_encoder.reset(new WrappedConverter());
279 }
280 input_encoder->init(input_encoding_);
281 OP_REQUIRES(ctx, input_encoder->converter_,
282 errors::InvalidArgument(
283 "Could not create converter for input encoding: " +
284 input_encoding_));
285
286 // Output may be forwardable from input, in which case work in-place.
287 Tensor* output_tensor;
288 std::unique_ptr<Tensor> maybe_forwarded =
289 ctx->forward_input(0 /*input_index*/, 0 /*output_index*/,
290 tensorflow::DT_STRING, input_tensor->shape(),
291 ctx->input_memory_type(0), ctx->input_alloc_attr(0));
292 if (maybe_forwarded) {
293 output_tensor = maybe_forwarded.get();
294 OP_REQUIRES_OK(ctx, ctx->set_output("output", *output_tensor));
295 } else {
296 OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
297 &output_tensor));
298 output_tensor->flat<string>() = input_tensor->flat<string>();
299 }
300
301 auto output_flat = output_tensor->flat<string>();
302 bool found_any_format_error = false;
303 for (size_t i = 0; i < output_flat.size(); ++i) {
304 Transcode(&(output_flat(i)), input_encoder->converter_,
305 &found_any_format_error);
306 }
307 if (error_options_.error_on_malformatting && found_any_format_error) {
308 ctx->CtxFailure(
309 errors::InvalidArgument("Invalid formatting on input string"));
310 }
311 }
312
313 private:
314 // Consume a codepoint from the input string and add it to the buffer.
315 // This function takes care of any replacement configuration on invalid or
316 // out-of-range inputs.
TranslateCodepoints(icu::UnicodeString * s,bool * found_any_format_error,UChar32 ch,int src_bytes,bool format_error)317 void TranslateCodepoints(icu::UnicodeString* s, bool* found_any_format_error,
318 UChar32 ch, int src_bytes, bool format_error) {
319 if (ShouldHandleFormatError(error_options_, ch, format_error)) {
320 *found_any_format_error = true;
321 if (error_options_.elide_replacement) {
322 return;
323 } else {
324 ch = error_options_.subst;
325 }
326 }
327 s->append(ch);
328 }
329
330 // Transcode the string from input encoding to the output_encoding_. If
331 // non-valid characters are encountered, use the subst_/elide_replacement_
332 // config to handle them.
Transcode(string * s,UConverter * input_encoder,bool * found_any_format_error)333 void Transcode(string* s, UConverter* input_encoder,
334 bool* found_any_format_error) {
335 icu::UnicodeString source;
336 IterateUnicodeString(
337 *s, input_encoder,
338 std::bind(&UnicodeTranscodeOp::TranslateCodepoints, this, &source,
339 found_any_format_error, std::placeholders::_1,
340 std::placeholders::_2, std::placeholders::_3));
341
342 Encode(output_encoding_, source, s);
343 }
344
345 string input_encoding_;
346 ErrorOptions error_options_;
347 UnicodeEncoding output_encoding_ = UnicodeEncoding::UTF8;
348 };
349
350 REGISTER_KERNEL_BUILDER(Name("UnicodeTranscode").Device(DEVICE_CPU),
351 UnicodeTranscodeOp);
352
353 class UnicodeDecodeBaseOp : public OpKernel {
354 public:
UnicodeDecodeBaseOp(OpKernelConstruction * ctx,bool generate_offsets)355 explicit UnicodeDecodeBaseOp(OpKernelConstruction* ctx, bool generate_offsets)
356 : OpKernel(ctx), generate_offsets_(generate_offsets) {
357 OP_REQUIRES_OK(ctx, GetErrorOptions(ctx, &error_options_));
358 OP_REQUIRES_OK(ctx, ctx->GetAttr("input_encoding", &input_encoding_));
359 // Make a temporary UConverter to ensure it will create without error
360 // at execution time (and to warm any data caches the converter needs).
361 // This instance is not used.
362 std::unique_ptr<WrappedConverter> input_encoder =
363 absl::make_unique<WrappedConverter>();
364 input_encoder->init(input_encoding_);
365 OP_REQUIRES(ctx, input_encoder->converter_,
366 errors::InvalidArgument(
367 "Could not create converter for input encoding: " +
368 input_encoding_));
369 }
370
Decode(OpKernelContext * ctx,std::vector<UChar32> * char_values,std::vector<int64> * offset_values,int * current_offset,int64 * next_row_split,UChar32 char_value,int char_length,bool found_any_format_error)371 void Decode(OpKernelContext* ctx, std::vector<UChar32>* char_values,
372 std::vector<int64>* offset_values, int* current_offset,
373 int64* next_row_split, UChar32 char_value, int char_length,
374 bool found_any_format_error) {
375 if (error_options_.error_on_malformatting && found_any_format_error) {
376 ctx->CtxFailure(
377 errors::InvalidArgument("Invalid formatting on input string"));
378 }
379 UChar32 decoded_value = char_value;
380 if (ShouldHandleFormatError(error_options_, char_value,
381 found_any_format_error)) {
382 if (error_options_.elide_replacement && (offset_values != nullptr)) {
383 *current_offset += char_length;
384 return;
385 } else {
386 decoded_value = error_options_.subst;
387 }
388 }
389
390 // Emit the char value.
391 char_values->push_back(decoded_value);
392
393 // Emit the byte offset
394 if (offset_values != nullptr) {
395 offset_values->push_back(*current_offset);
396 *current_offset += char_length;
397 }
398 *next_row_split += 1;
399 }
400
Compute(OpKernelContext * ctx)401 void Compute(OpKernelContext* ctx) override {
402 const Tensor* input_tensor;
403 OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
404
405 // Go through all the strings in `input`.
406 const auto& input_vec = input_tensor->flat<string>();
407
408 std::unique_ptr<WrappedConverter> input_encoder =
409 absl::make_unique<WrappedConverter>();
410 input_encoder->init(input_encoding_);
411 OP_REQUIRES(ctx, input_encoder->converter_,
412 errors::InvalidArgument(
413 "Could not create converter for input encoding: " +
414 input_encoding_));
415
416 std::vector<UChar32> char_values;
417 std::vector<int64> offset_values;
418
419 Tensor* output_row_splits;
420 OP_REQUIRES_OK(ctx, ctx->allocate_output("row_splits",
421 {input_tensor->NumElements() + 1},
422 &output_row_splits));
423 auto out_row_splits = output_row_splits->vec<int64>();
424
425 int row_split_index = 0;
426 int64 next_row_split = 0;
427 for (int i = 0; i < input_vec.size(); ++i) {
428 const string& input = input_vec(i);
429 // Convert input strings into unicode values. Output to a list of
430 // char_values, record row splits and char_to_byte_starts, which are all
431 // the fields needed to construct a RaggedTensor.
432 out_row_splits(row_split_index) = next_row_split;
433 row_split_index++;
434 int current_offset = 0;
435 IterateUnicodeString(
436 input, input_encoder->converter_,
437 std::bind(&UnicodeDecodeBaseOp::Decode, this, ctx, &char_values,
438 &offset_values, ¤t_offset, &next_row_split,
439 std::placeholders::_1, std::placeholders::_2,
440 std::placeholders::_3));
441 }
442 out_row_splits(row_split_index) = next_row_split;
443
444 Tensor* output_char_values;
445 OP_REQUIRES_OK(
446 ctx, ctx->allocate_output("char_values",
447 {static_cast<int64>(char_values.size())},
448 &output_char_values));
449 auto out_char_values = output_char_values->vec<int32>();
450 if (generate_offsets_) {
451 DCHECK(offset_values.size() == char_values.size());
452 Tensor* output_offset_values;
453 OP_REQUIRES_OK(
454 ctx, ctx->allocate_output("char_to_byte_starts",
455 {static_cast<int64>(offset_values.size())},
456 &output_offset_values));
457 auto out_offset_values = output_offset_values->vec<int64>();
458
459 // Load output tensors from intermediate value arrays.
460 for (int i = 0; i < char_values.size(); ++i) {
461 out_char_values(i) = static_cast<int32>(char_values[i]);
462 out_offset_values(i) = offset_values[i];
463 }
464 } else {
465 for (int i = 0; i < char_values.size(); ++i) {
466 out_char_values(i) = static_cast<int32>(char_values[i]);
467 }
468 }
469 }
470
471 private:
472 string input_encoding_;
473 ErrorOptions error_options_;
474 bool generate_offsets_ = false;
475 };
476
477 class UnicodeDecodeOp : public UnicodeDecodeBaseOp {
478 public:
UnicodeDecodeOp(OpKernelConstruction * ctx)479 explicit UnicodeDecodeOp(OpKernelConstruction* ctx)
480 : UnicodeDecodeBaseOp(ctx, false) {}
481 };
482
483 class UnicodeDecodeWithOffsetsOp : public UnicodeDecodeBaseOp {
484 public:
UnicodeDecodeWithOffsetsOp(OpKernelConstruction * ctx)485 explicit UnicodeDecodeWithOffsetsOp(OpKernelConstruction* ctx)
486 : UnicodeDecodeBaseOp(ctx, true) {}
487 };
488
489 REGISTER_KERNEL_BUILDER(Name("UnicodeDecode").Device(DEVICE_CPU),
490 UnicodeDecodeOp);
491 REGISTER_KERNEL_BUILDER(Name("UnicodeDecodeWithOffsets").Device(DEVICE_CPU),
492 UnicodeDecodeWithOffsetsOp);
493
494 class UnicodeEncodeOp : public OpKernel {
495 public:
UnicodeEncodeOp(OpKernelConstruction * ctx)496 explicit UnicodeEncodeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
497 string encoding_tmp;
498 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_encoding", &encoding_tmp));
499 OP_REQUIRES_OK(ctx, ParseUnicodeEncoding(encoding_tmp, &encoding_));
500 OP_REQUIRES_OK(ctx, GetErrorOptions(ctx, &error_options_));
501 }
502
503 /**
504 * Encodes Unicode codepoints into the desired string representation.
505 *
506 * We lose a dimension while encoding, since a series of integer codepoints is
507 * encoded into a single string.
508 *
509 * This accepts two input tensors: a rank 1 tensor of code point values and
510 * a single rank 1 tensor of splits which determine where each string begins
511 * and ends from the provided code points.
512 */
Compute(OpKernelContext * context)513 void Compute(OpKernelContext* context) override {
514 // Get inputs
515 const Tensor& input_tensor = context->input(0);
516 const auto input_tensor_flat = input_tensor.flat<int32>();
517 const Tensor& input_splits = context->input(1);
518 const auto input_splits_flat = input_splits.flat<int64>();
519
520 // Since we limit to a 2-D input (flat_values of rank 1 and a single splits
521 // tensor), our output dimension will be 1 with it's size equal to the
522 // number of splits (outer dimension or ragged tensor).
523 TensorShape output_shape({input_splits.dim_size(0) - 1});
524 Tensor* output_tensor;
525 OP_REQUIRES_OK(context, context->allocate_output("output", output_shape,
526 &output_tensor));
527 auto output_tensor_flat = output_tensor->flat<string>();
528
529 // Use a single index over the flattened input values tensor.
530 int idx = 0;
531 // Loop through our split dimension to create a new string at each split.
532 for (int i = 1; i < input_splits_flat.size(); ++i) {
533 icu::UnicodeString unicode_string;
534 icu::UnicodeStringAppendable appendable_unicode_string(unicode_string);
535 for (; idx < input_splits_flat(i); ++idx) {
536 int32 code_point = input_tensor_flat(idx);
537 // Check for invalid code point
538 if (code_point > UCHAR_MAX_VALUE || code_point < UCHAR_MIN_VALUE) {
539 if (error_options_.error_on_malformatting) {
540 context->CtxFailure(errors::InvalidArgument(
541 "Code point value out of valid Unicode range."));
542 return;
543 } else if (!error_options_.elide_replacement) {
544 code_point = error_options_.subst;
545 }
546 }
547 appendable_unicode_string.appendCodePoint(code_point);
548 }
549 // Encode our string and save in the output.
550 string result;
551 Encode(encoding_, unicode_string, &result);
552 output_tensor_flat(i - 1) = result;
553 }
554 }
555
556 private:
557 UnicodeEncoding encoding_;
558 ErrorOptions error_options_;
559 };
560
561 REGISTER_KERNEL_BUILDER(Name("UnicodeEncode").Device(DEVICE_CPU),
562 UnicodeEncodeOp);
563
564 } // namespace tensorflow
565