1# -*- coding: utf-8 -*- 2# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16 17"""Operations for working with string Tensors.""" 18 19from __future__ import absolute_import 20from __future__ import division 21from __future__ import print_function 22 23import numpy as np 24 25from tensorflow.python.framework import constant_op 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import sparse_tensor 29from tensorflow.python.framework import tensor_util 30from tensorflow.python.ops import array_ops 31from tensorflow.python.ops import gen_parsing_ops 32from tensorflow.python.ops import gen_string_ops 33from tensorflow.python.ops import math_ops 34 35# go/tf-wildcard-import 36# pylint: disable=wildcard-import 37# pylint: disable=g-bad-import-order 38from tensorflow.python.ops.gen_string_ops import * 39from tensorflow.python.util import compat as util_compat 40from tensorflow.python.util import deprecation 41from tensorflow.python.util import dispatch 42from tensorflow.python.util.tf_export import tf_export 43# pylint: enable=g-bad-import-order 44# pylint: enable=wildcard-import 45 46 47# pylint: disable=redefined-builtin 48@tf_export("strings.regex_full_match") 49@dispatch.add_dispatch_support 50def regex_full_match(input, pattern, name=None): 51 r"""Match elements of `input` with regex `pattern`. 52 53 Args: 54 input: string `Tensor`, the source strings to process. 55 pattern: string or scalar string `Tensor`, regular expression to use, 56 see more details at https://github.com/google/re2/wiki/Syntax 57 name: Name of the op. 58 59 Returns: 60 bool `Tensor` of the same shape as `input` with match results. 61 """ 62 if isinstance(pattern, util_compat.bytes_or_text_types): 63 # When `pattern` is static through the life of the op we can 64 # use a version which performs the expensive regex compilation once at 65 # creation time. 66 return gen_string_ops.static_regex_full_match( 67 input=input, pattern=pattern, name=name) 68 return gen_string_ops.regex_full_match( 69 input=input, pattern=pattern, name=name) 70 71regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__ 72 73 74@tf_export( 75 "strings.regex_replace", v1=["strings.regex_replace", "regex_replace"]) 76@dispatch.add_dispatch_support 77@deprecation.deprecated_endpoints("regex_replace") 78@dispatch.add_dispatch_support 79def regex_replace(input, pattern, rewrite, replace_global=True, name=None): 80 r"""Replace elements of `input` matching regex `pattern` with `rewrite`. 81 82 >>> tf.strings.regex_replace("Text with tags.<br /><b>contains html</b>", 83 ... "<[^>]+>", " ") 84 <tf.Tensor: shape=(), dtype=string, numpy=b'Text with tags. contains html '> 85 86 Args: 87 input: string `Tensor`, the source strings to process. 88 pattern: string or scalar string `Tensor`, regular expression to use, 89 see more details at https://github.com/google/re2/wiki/Syntax 90 rewrite: string or scalar string `Tensor`, value to use in match 91 replacement, supports backslash-escaped digits (\1 to \9) can be to insert 92 text matching corresponding parenthesized group. 93 replace_global: `bool`, if `True` replace all non-overlapping matches, 94 else replace only the first match. 95 name: A name for the operation (optional). 96 97 Returns: 98 string `Tensor` of the same shape as `input` with specified replacements. 99 """ 100 if (isinstance(pattern, util_compat.bytes_or_text_types) and 101 isinstance(rewrite, util_compat.bytes_or_text_types)): 102 # When `pattern` and `rewrite` are static through the life of the op we can 103 # use a version which performs the expensive regex compilation once at 104 # creation time. 105 return gen_string_ops.static_regex_replace( 106 input=input, pattern=pattern, 107 rewrite=rewrite, replace_global=replace_global, 108 name=name) 109 return gen_string_ops.regex_replace( 110 input=input, pattern=pattern, 111 rewrite=rewrite, replace_global=replace_global, 112 name=name) 113 114 115@tf_export("strings.format") 116@dispatch.add_dispatch_support 117def string_format(template, inputs, placeholder="{}", summarize=3, name=None): 118 r"""Formats a string template using a list of tensors. 119 120 Formats a string template using a list of tensors, abbreviating tensors by 121 only printing the first and last `summarize` elements of each dimension 122 (recursively). If formatting only one tensor into a template, the tensor does 123 not have to be wrapped in a list. 124 125 Example: 126 Formatting a single-tensor template: 127 128 >>> tensor = tf.range(5) 129 >>> tf.strings.format("tensor: {}, suffix", tensor) 130 <tf.Tensor: shape=(), dtype=string, numpy=b'tensor: [0 1 2 3 4], suffix'> 131 132 Formatting a multi-tensor template: 133 134 >>> tensor_a = tf.range(2) 135 >>> tensor_b = tf.range(1, 4, 2) 136 >>> tf.strings.format("a: {}, b: {}, suffix", (tensor_a, tensor_b)) 137 <tf.Tensor: shape=(), dtype=string, numpy=b'a: [0 1], b: [1 3], suffix'> 138 139 140 Args: 141 template: A string template to format tensor values into. 142 inputs: A list of `Tensor` objects, or a single Tensor. 143 The list of tensors to format into the template string. If a solitary 144 tensor is passed in, the input tensor will automatically be wrapped as a 145 list. 146 placeholder: An optional `string`. Defaults to `{}`. 147 At each placeholder occurring in the template, a subsequent tensor 148 will be inserted. 149 summarize: An optional `int`. Defaults to `3`. 150 When formatting the tensors, show the first and last `summarize` 151 entries of each tensor dimension (recursively). If set to -1, all 152 elements of the tensor will be shown. 153 name: A name for the operation (optional). 154 155 Returns: 156 A scalar `Tensor` of type `string`. 157 158 Raises: 159 ValueError: if the number of placeholders does not match the number of 160 inputs. 161 """ 162 # If there is only one tensor to format, we will automatically wrap it in a 163 # list to simplify the user experience 164 if tensor_util.is_tf_type(inputs): 165 inputs = [inputs] 166 if template.count(placeholder) != len(inputs): 167 raise ValueError(f"The template expects {template.count(placeholder)} " 168 f"tensors, but the inputs only has {len(inputs)}. " 169 "Please ensure the number of placeholders in template " 170 "matches inputs length.") 171 172 return gen_string_ops.string_format(inputs, 173 template=template, 174 placeholder=placeholder, 175 summarize=summarize, 176 name=name) 177 178 179# Note: tf.strings.split is exported in ragged/ragged_string_ops.py, which 180# defines a wrapper for this function. 181def string_split(source, sep=None, skip_empty=True, delimiter=None): # pylint: disable=invalid-name 182 """Split elements of `source` based on `delimiter` into a `SparseTensor`. 183 184 Let N be the size of source (typically N will be the batch size). Split each 185 element of `source` based on `delimiter` and return a `SparseTensor` 186 containing the split tokens. Empty tokens are ignored. 187 188 If `sep` is an empty string, each element of the `source` is split 189 into individual strings, each containing one byte. (This includes splitting 190 multibyte sequences of UTF-8.) If delimiter contains multiple bytes, it is 191 treated as a set of delimiters with each considered a potential split point. 192 193 For example: 194 N = 2, source[0] is 'hello world' and source[1] is 'a b c', then the output 195 will be 196 197 st.indices = [0, 0; 198 0, 1; 199 1, 0; 200 1, 1; 201 1, 2] 202 st.shape = [2, 3] 203 st.values = ['hello', 'world', 'a', 'b', 'c'] 204 205 Args: 206 source: `1-D` string `Tensor`, the strings to split. 207 sep: `0-D` string `Tensor`, the delimiter character, the string should 208 be length 0 or 1. Default is ' '. 209 skip_empty: A `bool`. If `True`, skip the empty strings from the result. 210 delimiter: deprecated alias for `sep`. 211 212 Raises: 213 ValueError: If delimiter is not a string. 214 215 Returns: 216 A `SparseTensor` of rank `2`, the strings split according to the delimiter. 217 The first column of the indices corresponds to the row in `source` and the 218 second column corresponds to the index of the split component in this row. 219 """ 220 delimiter = deprecation.deprecated_argument_lookup( 221 "sep", sep, "delimiter", delimiter) 222 223 if delimiter is None: 224 delimiter = " " 225 delimiter = ops.convert_to_tensor(delimiter, dtype=dtypes.string) 226 source = ops.convert_to_tensor(source, dtype=dtypes.string) 227 228 indices, values, shape = gen_string_ops.string_split( 229 source, delimiter=delimiter, skip_empty=skip_empty) 230 indices.set_shape([None, 2]) 231 values.set_shape([None]) 232 shape.set_shape([2]) 233 return sparse_tensor.SparseTensor(indices, values, shape) 234 235 236# Note: tf.strings.split is exported in ragged/ragged_string_ops.py, which 237# defines a wrapper for this function. 238def string_split_v2(source, sep=None, maxsplit=-1): 239 """Split elements of `source` based on `sep` into a `SparseTensor`. 240 241 Let N be the size of source (typically N will be the batch size). Split each 242 element of `source` based on `sep` and return a `SparseTensor` 243 containing the split tokens. Empty tokens are ignored. 244 245 For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c', 246 then the output will be 247 248 st.indices = [0, 0; 249 0, 1; 250 1, 0; 251 1, 1; 252 1, 2] 253 st.shape = [2, 3] 254 st.values = ['hello', 'world', 'a', 'b', 'c'] 255 256 If `sep` is given, consecutive delimiters are not grouped together and are 257 deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and 258 sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty 259 string, consecutive whitespace are regarded as a single separator, and the 260 result will contain no empty strings at the start or end if the string has 261 leading or trailing whitespace. 262 263 Note that the above mentioned behavior matches python's str.split. 264 265 Args: 266 source: `1-D` string `Tensor`, the strings to split. 267 sep: `0-D` string `Tensor`, the delimiter character. 268 maxsplit: An `int`. If `maxsplit > 0`, limit of the split of the result. 269 270 Raises: 271 ValueError: If sep is not a string. 272 273 Returns: 274 A `SparseTensor` of rank `2`, the strings split according to the delimiter. 275 The first column of the indices corresponds to the row in `source` and the 276 second column corresponds to the index of the split component in this row. 277 """ 278 if sep is None: 279 sep = "" 280 sep = ops.convert_to_tensor(sep, dtype=dtypes.string) 281 source = ops.convert_to_tensor(source, dtype=dtypes.string) 282 283 indices, values, shape = gen_string_ops.string_split_v2( 284 source, sep=sep, maxsplit=maxsplit) 285 indices.set_shape([None, 2]) 286 values.set_shape([None]) 287 shape.set_shape([2]) 288 return sparse_tensor.SparseTensor(indices, values, shape) 289 290 291def _reduce_join_reduction_dims(x, axis): 292 """Returns range(rank(x) - 1, 0, -1) if axis is None; or axis otherwise.""" 293 if axis is not None: 294 return axis 295 else: 296 # Fast path: avoid creating Rank and Range ops if ndims is known. 297 if x.get_shape().ndims is not None: 298 return constant_op.constant( 299 np.arange(x.get_shape().ndims - 1, -1, -1), dtype=dtypes.int32) 300 301 # Otherwise, we rely on Range and Rank to do the right thing at run-time. 302 return math_ops.range(array_ops.rank(x) - 1, -1, -1) 303 304 305@tf_export(v1=["strings.reduce_join", "reduce_join"]) 306@dispatch.add_dispatch_support 307@deprecation.deprecated_args(None, 308 "keep_dims is deprecated, use keepdims instead", 309 "keep_dims") 310@deprecation.deprecated_endpoints("reduce_join") 311def reduce_join(inputs, axis=None, # pylint: disable=missing-docstring 312 keep_dims=None, 313 separator="", 314 name=None, 315 reduction_indices=None, 316 keepdims=None): 317 keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims, 318 "keep_dims", keep_dims) 319 if keep_dims is None: 320 keep_dims = False 321 axis = deprecation.deprecated_argument_lookup("axis", axis, 322 "reduction_indices", 323 reduction_indices) 324 return reduce_join_v2( 325 inputs=inputs, 326 axis=axis, 327 keepdims=keepdims, 328 separator=separator, 329 name=name) 330 331 332@tf_export("strings.reduce_join", v1=[]) 333@dispatch.add_dispatch_support 334def reduce_join_v2( # pylint: disable=missing-docstring 335 inputs, 336 axis=None, 337 keepdims=False, 338 separator="", 339 name=None): 340 """Joins all strings into a single string, or joins along an axis. 341 342 This is the reduction operation for the elementwise `tf.strings.join` op. 343 344 >>> tf.strings.reduce_join([['abc','123'], 345 ... ['def','456']]).numpy() 346 b'abc123def456' 347 >>> tf.strings.reduce_join([['abc','123'], 348 ... ['def','456']], axis=-1).numpy() 349 array([b'abc123', b'def456'], dtype=object) 350 >>> tf.strings.reduce_join([['abc','123'], 351 ... ['def','456']], 352 ... axis=-1, 353 ... separator=" ").numpy() 354 array([b'abc 123', b'def 456'], dtype=object) 355 356 Args: 357 inputs: A `tf.string` tensor. 358 axis: Which axis to join along. The default behavior is to join all 359 elements, producing a scalar. 360 keepdims: If true, retains reduced dimensions with length 1. 361 separator: a string added between each string being joined. 362 name: A name for the operation (optional). 363 364 Returns: 365 A `tf.string` tensor. 366 """ 367 with ops.name_scope(None, "ReduceJoin", [inputs, axis]): 368 inputs_t = ops.convert_to_tensor(inputs) 369 axis = _reduce_join_reduction_dims(inputs_t, axis) 370 return gen_string_ops.reduce_join( 371 inputs=inputs_t, 372 reduction_indices=axis, 373 keep_dims=keepdims, 374 separator=separator, 375 name=name) 376 377reduce_join.__doc__ = reduce_join_v2.__doc__ 378 379 380# This wrapper provides backwards compatibility for code that predates the 381# unit argument and that passed 'name' as a positional argument. 382@tf_export(v1=["strings.length"]) 383@dispatch.add_dispatch_support 384def string_length(input, name=None, unit="BYTE"): 385 """Computes the length of each string given in the input tensor. 386 387 >>> strings = tf.constant(['Hello','TensorFlow', '']) 388 >>> tf.strings.length(strings).numpy() # default counts bytes 389 array([ 5, 10, 4], dtype=int32) 390 >>> tf.strings.length(strings, unit="UTF8_CHAR").numpy() 391 array([ 5, 10, 1], dtype=int32) 392 393 Args: 394 input: A `Tensor` of type `string`. The strings for which to compute the 395 length for each element. 396 name: A name for the operation (optional). 397 unit: An optional `string` from: `"BYTE", "UTF8_CHAR"`. Defaults to 398 `"BYTE"`. The unit that is counted to compute string length. One of: 399 `"BYTE"` (for the number of bytes in each string) or `"UTF8_CHAR"` (for 400 the number of UTF-8 encoded Unicode code points in each string). Results 401 are undefined if `unit=UTF8_CHAR` and the `input` strings do not contain 402 structurally valid UTF-8. 403 404 Returns: 405 A `Tensor` of type `int32`, containing the length of the input string in 406 the same element of the input tensor. 407 """ 408 return gen_string_ops.string_length(input, unit=unit, name=name) 409 410 411@tf_export("strings.length", v1=[]) 412@dispatch.add_dispatch_support 413def string_length_v2(input, unit="BYTE", name=None): 414 return gen_string_ops.string_length(input, unit=unit, name=name) 415 416 417string_length_v2.__doc__ = gen_string_ops.string_length.__doc__ 418 419 420@tf_export(v1=["substr"]) 421@dispatch.add_dispatch_support 422@deprecation.deprecated(None, "Use `tf.strings.substr` instead of `tf.substr`.") 423def substr_deprecated(input, pos, len, name=None, unit="BYTE"): 424 return substr(input, pos, len, name=name, unit=unit) 425 426substr_deprecated.__doc__ = gen_string_ops.substr.__doc__ 427 428 429@tf_export(v1=["strings.substr"]) 430@dispatch.add_dispatch_support 431def substr(input, pos, len, name=None, unit="BYTE"): 432 return gen_string_ops.substr(input, pos, len, unit=unit, name=name) 433 434substr.__doc__ = gen_string_ops.substr.__doc__ 435 436 437@tf_export("strings.substr", v1=[]) 438@dispatch.add_dispatch_support 439def substr_v2(input, pos, len, unit="BYTE", name=None): 440 return gen_string_ops.substr(input, pos, len, unit=unit, name=name) 441 442substr_v2.__doc__ = gen_string_ops.substr.__doc__ 443 444 445ops.NotDifferentiable("RegexReplace") 446ops.NotDifferentiable("StringToHashBucket") 447ops.NotDifferentiable("StringToHashBucketFast") 448ops.NotDifferentiable("StringToHashBucketStrong") 449ops.NotDifferentiable("ReduceJoin") 450ops.NotDifferentiable("StringJoin") 451ops.NotDifferentiable("StringSplit") 452ops.NotDifferentiable("AsString") 453ops.NotDifferentiable("EncodeBase64") 454ops.NotDifferentiable("DecodeBase64") 455 456 457@tf_export("strings.to_number", v1=[]) 458@dispatch.add_dispatch_support 459def string_to_number(input, out_type=dtypes.float32, name=None): 460 r"""Converts each string in the input Tensor to the specified numeric type. 461 462 (Note that int32 overflow results in an error while float overflow 463 results in a rounded value.) 464 465 Examples: 466 467 >>> tf.strings.to_number("1.55") 468 <tf.Tensor: shape=(), dtype=float32, numpy=1.55> 469 >>> tf.strings.to_number("3", tf.int32) 470 <tf.Tensor: shape=(), dtype=int32, numpy=3> 471 472 Args: 473 input: A `Tensor` of type `string`. 474 out_type: An optional `tf.DType` from: `tf.float32, tf.float64, tf.int32, 475 tf.int64`. Defaults to `tf.float32`. 476 The numeric type to interpret each string in `string_tensor` as. 477 name: A name for the operation (optional). 478 479 Returns: 480 A `Tensor` of type `out_type`. 481 """ 482 return gen_parsing_ops.string_to_number(input, out_type, name) 483 484 485@tf_export(v1=["strings.to_number", "string_to_number"]) 486@dispatch.add_dispatch_support 487def string_to_number_v1( 488 string_tensor=None, 489 out_type=dtypes.float32, 490 name=None, 491 input=None): 492 string_tensor = deprecation.deprecated_argument_lookup( 493 "input", input, "string_tensor", string_tensor) 494 return gen_parsing_ops.string_to_number(string_tensor, out_type, name) 495 496string_to_number_v1.__doc__ = gen_parsing_ops.string_to_number.__doc__ 497 498 499@tf_export("strings.to_hash_bucket", v1=[]) 500@dispatch.add_dispatch_support 501def string_to_hash_bucket(input, num_buckets, name=None): 502 # pylint: disable=line-too-long 503 r"""Converts each string in the input Tensor to its hash mod by a number of buckets. 504 505 The hash function is deterministic on the content of the string within the 506 process. 507 508 Note that the hash function may change from time to time. 509 This functionality will be deprecated and it's recommended to use 510 `tf.strings.to_hash_bucket_fast()` or `tf.strings.to_hash_bucket_strong()`. 511 512 Examples: 513 514 >>> tf.strings.to_hash_bucket(["Hello", "TensorFlow", "2.x"], 3) 515 <tf.Tensor: shape=(3,), dtype=int64, numpy=array([2, 0, 1])> 516 517 Args: 518 input: A `Tensor` of type `string`. 519 num_buckets: An `int` that is `>= 1`. The number of buckets. 520 name: A name for the operation (optional). 521 522 Returns: 523 A `Tensor` of type `int64`. 524 """ 525 # pylint: enable=line-too-long 526 return gen_string_ops.string_to_hash_bucket(input, num_buckets, name) 527 528 529@tf_export(v1=["strings.to_hash_bucket", "string_to_hash_bucket"]) 530@dispatch.add_dispatch_support 531def string_to_hash_bucket_v1( 532 string_tensor=None, 533 num_buckets=None, 534 name=None, 535 input=None): 536 string_tensor = deprecation.deprecated_argument_lookup( 537 "input", input, "string_tensor", string_tensor) 538 return gen_string_ops.string_to_hash_bucket(string_tensor, num_buckets, name) 539 540string_to_hash_bucket_v1.__doc__ = gen_string_ops.string_to_hash_bucket.__doc__ 541 542 543@tf_export("strings.join", v1=["strings.join", "string_join"]) 544@dispatch.add_dispatch_support 545@deprecation.deprecated_endpoints("string_join") 546@dispatch.add_dispatch_support 547def string_join(inputs, separator="", name=None): 548 """Perform element-wise concatenation of a list of string tensors. 549 550 Given a list of string tensors of same shape, performs element-wise 551 concatenation of the strings of the same index in all tensors. 552 553 554 >>> tf.strings.join(['abc','def']).numpy() 555 b'abcdef' 556 >>> tf.strings.join([['abc','123'], 557 ... ['def','456'], 558 ... ['ghi','789']]).numpy() 559 array([b'abcdefghi', b'123456789'], dtype=object) 560 >>> tf.strings.join([['abc','123'], 561 ... ['def','456']], 562 ... separator=" ").numpy() 563 array([b'abc def', b'123 456'], dtype=object) 564 565 The reduction version of this elementwise operation is 566 `tf.strings.reduce_join` 567 568 Args: 569 inputs: A list of `tf.Tensor` objects of same size and `tf.string` dtype. 570 separator: A string added between each string being joined. 571 name: A name for the operation (optional). 572 573 Returns: 574 A `tf.string` tensor. 575 """ 576 return gen_string_ops.string_join(inputs, separator=separator, name=name) 577