1# -*- coding: utf-8 -*- 2# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16 17"""Operations for working with string Tensors.""" 18 19from __future__ import absolute_import 20from __future__ import division 21from __future__ import print_function 22 23import numpy as np 24 25from tensorflow.python.framework import constant_op 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import sparse_tensor 29from tensorflow.python.framework import tensor_util 30from tensorflow.python.ops import array_ops 31from tensorflow.python.ops import gen_parsing_ops 32from tensorflow.python.ops import gen_string_ops 33from tensorflow.python.ops import math_ops 34 35# go/tf-wildcard-import 36# pylint: disable=wildcard-import 37# pylint: disable=g-bad-import-order 38from tensorflow.python.ops.gen_string_ops import * 39from tensorflow.python.util import compat as util_compat 40from tensorflow.python.util import deprecation 41from tensorflow.python.util import dispatch 42from tensorflow.python.util.tf_export import tf_export 43# pylint: enable=g-bad-import-order 44# pylint: enable=wildcard-import 45 46 47# pylint: disable=redefined-builtin 48@tf_export("strings.regex_full_match") 49@dispatch.add_dispatch_support 50def regex_full_match(input, pattern, name=None): 51 r"""Match elements of `input` with regex `pattern`. 52 53 Args: 54 input: string `Tensor`, the source strings to process. 55 pattern: string or scalar string `Tensor`, regular expression to use, 56 see more details at https://github.com/google/re2/wiki/Syntax 57 name: Name of the op. 58 59 Returns: 60 bool `Tensor` of the same shape as `input` with match results. 61 """ 62 if isinstance(pattern, util_compat.bytes_or_text_types): 63 # When `pattern` is static through the life of the op we can 64 # use a version which performs the expensive regex compilation once at 65 # creation time. 66 return gen_string_ops.static_regex_full_match( 67 input=input, pattern=pattern, name=name) 68 return gen_string_ops.regex_full_match( 69 input=input, pattern=pattern, name=name) 70 71regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__ 72 73 74@tf_export( 75 "strings.regex_replace", v1=["strings.regex_replace", "regex_replace"]) 76@dispatch.add_dispatch_support 77@deprecation.deprecated_endpoints("regex_replace") 78@dispatch.add_dispatch_support 79def regex_replace(input, pattern, rewrite, replace_global=True, name=None): 80 r"""Replace elements of `input` matching regex `pattern` with `rewrite`. 81 82 >>> tf.strings.regex_replace("Text with tags.<br /><b>contains html</b>", 83 ... "<[^>]+>", " ") 84 <tf.Tensor: shape=(), dtype=string, numpy=b'Text with tags. contains html '> 85 86 Args: 87 input: string `Tensor`, the source strings to process. 88 pattern: string or scalar string `Tensor`, regular expression to use, 89 see more details at https://github.com/google/re2/wiki/Syntax 90 rewrite: string or scalar string `Tensor`, value to use in match 91 replacement, supports backslash-escaped digits (\1 to \9) can be to insert 92 text matching corresponding parenthesized group. 93 replace_global: `bool`, if `True` replace all non-overlapping matches, 94 else replace only the first match. 95 name: A name for the operation (optional). 96 97 Returns: 98 string `Tensor` of the same shape as `input` with specified replacements. 99 """ 100 if (isinstance(pattern, util_compat.bytes_or_text_types) and 101 isinstance(rewrite, util_compat.bytes_or_text_types)): 102 # When `pattern` and `rewrite` are static through the life of the op we can 103 # use a version which performs the expensive regex compilation once at 104 # creation time. 105 return gen_string_ops.static_regex_replace( 106 input=input, pattern=pattern, 107 rewrite=rewrite, replace_global=replace_global, 108 name=name) 109 return gen_string_ops.regex_replace( 110 input=input, pattern=pattern, 111 rewrite=rewrite, replace_global=replace_global, 112 name=name) 113 114 115@tf_export("strings.format") 116@dispatch.add_dispatch_support 117def string_format(template, inputs, placeholder="{}", summarize=3, name=None): 118 r"""Formats a string template using a list of tensors. 119 120 Formats a string template using a list of tensors, abbreviating tensors by 121 only printing the first and last `summarize` elements of each dimension 122 (recursively). If formatting only one tensor into a template, the tensor does 123 not have to be wrapped in a list. 124 125 Example: 126 Formatting a single-tensor template: 127 128 >>> tensor = tf.range(5) 129 >>> tf.strings.format("tensor: {}, suffix", tensor) 130 <tf.Tensor: shape=(), dtype=string, numpy=b'tensor: [0 1 2 3 4], suffix'> 131 132 Formatting a multi-tensor template: 133 134 >>> tensor_a = tf.range(2) 135 >>> tensor_b = tf.range(1, 4, 2) 136 >>> tf.strings.format("a: {}, b: {}, suffix", (tensor_a, tensor_b)) 137 <tf.Tensor: shape=(), dtype=string, numpy=b'a: [0 1], b: [1 3], suffix'> 138 139 140 Args: 141 template: A string template to format tensor values into. 142 inputs: A list of `Tensor` objects, or a single Tensor. 143 The list of tensors to format into the template string. If a solitary 144 tensor is passed in, the input tensor will automatically be wrapped as a 145 list. 146 placeholder: An optional `string`. Defaults to `{}`. 147 At each placeholder occurring in the template, a subsequent tensor 148 will be inserted. 149 summarize: An optional `int`. Defaults to `3`. 150 When formatting the tensors, show the first and last `summarize` 151 entries of each tensor dimension (recursively). If set to -1, all 152 elements of the tensor will be shown. 153 name: A name for the operation (optional). 154 155 Returns: 156 A scalar `Tensor` of type `string`. 157 158 Raises: 159 ValueError: if the number of placeholders does not match the number of 160 inputs. 161 """ 162 # If there is only one tensor to format, we will automatically wrap it in a 163 # list to simplify the user experience 164 if tensor_util.is_tf_type(inputs): 165 inputs = [inputs] 166 if template.count(placeholder) != len(inputs): 167 raise ValueError("%s placeholder(s) in template does not match %s tensor(s)" 168 " provided as input" % (template.count(placeholder), 169 len(inputs))) 170 171 return gen_string_ops.string_format(inputs, 172 template=template, 173 placeholder=placeholder, 174 summarize=summarize, 175 name=name) 176 177 178# Note: tf.strings.split is exported in ragged/ragged_string_ops.py, which 179# defines a wrapper for this function. 180def string_split(source, sep=None, skip_empty=True, delimiter=None): # pylint: disable=invalid-name 181 """Split elements of `source` based on `delimiter` into a `SparseTensor`. 182 183 Let N be the size of source (typically N will be the batch size). Split each 184 element of `source` based on `delimiter` and return a `SparseTensor` 185 containing the split tokens. Empty tokens are ignored. 186 187 If `sep` is an empty string, each element of the `source` is split 188 into individual strings, each containing one byte. (This includes splitting 189 multibyte sequences of UTF-8.) If delimiter contains multiple bytes, it is 190 treated as a set of delimiters with each considered a potential split point. 191 192 For example: 193 N = 2, source[0] is 'hello world' and source[1] is 'a b c', then the output 194 will be 195 196 st.indices = [0, 0; 197 0, 1; 198 1, 0; 199 1, 1; 200 1, 2] 201 st.shape = [2, 3] 202 st.values = ['hello', 'world', 'a', 'b', 'c'] 203 204 Args: 205 source: `1-D` string `Tensor`, the strings to split. 206 sep: `0-D` string `Tensor`, the delimiter character, the string should 207 be length 0 or 1. Default is ' '. 208 skip_empty: A `bool`. If `True`, skip the empty strings from the result. 209 delimiter: deprecated alias for `sep`. 210 211 Raises: 212 ValueError: If delimiter is not a string. 213 214 Returns: 215 A `SparseTensor` of rank `2`, the strings split according to the delimiter. 216 The first column of the indices corresponds to the row in `source` and the 217 second column corresponds to the index of the split component in this row. 218 """ 219 delimiter = deprecation.deprecated_argument_lookup( 220 "sep", sep, "delimiter", delimiter) 221 222 if delimiter is None: 223 delimiter = " " 224 delimiter = ops.convert_to_tensor(delimiter, dtype=dtypes.string) 225 source = ops.convert_to_tensor(source, dtype=dtypes.string) 226 227 indices, values, shape = gen_string_ops.string_split( 228 source, delimiter=delimiter, skip_empty=skip_empty) 229 indices.set_shape([None, 2]) 230 values.set_shape([None]) 231 shape.set_shape([2]) 232 return sparse_tensor.SparseTensor(indices, values, shape) 233 234 235# Note: tf.strings.split is exported in ragged/ragged_string_ops.py, which 236# defines a wrapper for this function. 237def string_split_v2(source, sep=None, maxsplit=-1): 238 """Split elements of `source` based on `sep` into a `SparseTensor`. 239 240 Let N be the size of source (typically N will be the batch size). Split each 241 element of `source` based on `sep` and return a `SparseTensor` 242 containing the split tokens. Empty tokens are ignored. 243 244 For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c', 245 then the output will be 246 247 st.indices = [0, 0; 248 0, 1; 249 1, 0; 250 1, 1; 251 1, 2] 252 st.shape = [2, 3] 253 st.values = ['hello', 'world', 'a', 'b', 'c'] 254 255 If `sep` is given, consecutive delimiters are not grouped together and are 256 deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and 257 sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty 258 string, consecutive whitespace are regarded as a single separator, and the 259 result will contain no empty strings at the start or end if the string has 260 leading or trailing whitespace. 261 262 Note that the above mentioned behavior matches python's str.split. 263 264 Args: 265 source: `1-D` string `Tensor`, the strings to split. 266 sep: `0-D` string `Tensor`, the delimiter character. 267 maxsplit: An `int`. If `maxsplit > 0`, limit of the split of the result. 268 269 Raises: 270 ValueError: If sep is not a string. 271 272 Returns: 273 A `SparseTensor` of rank `2`, the strings split according to the delimiter. 274 The first column of the indices corresponds to the row in `source` and the 275 second column corresponds to the index of the split component in this row. 276 """ 277 if sep is None: 278 sep = "" 279 sep = ops.convert_to_tensor(sep, dtype=dtypes.string) 280 source = ops.convert_to_tensor(source, dtype=dtypes.string) 281 282 indices, values, shape = gen_string_ops.string_split_v2( 283 source, sep=sep, maxsplit=maxsplit) 284 indices.set_shape([None, 2]) 285 values.set_shape([None]) 286 shape.set_shape([2]) 287 return sparse_tensor.SparseTensor(indices, values, shape) 288 289 290def _reduce_join_reduction_dims(x, axis): 291 """Returns range(rank(x) - 1, 0, -1) if axis is None; or axis otherwise.""" 292 if axis is not None: 293 return axis 294 else: 295 # Fast path: avoid creating Rank and Range ops if ndims is known. 296 if x.get_shape().ndims is not None: 297 return constant_op.constant( 298 np.arange(x.get_shape().ndims - 1, -1, -1), dtype=dtypes.int32) 299 300 # Otherwise, we rely on Range and Rank to do the right thing at run-time. 301 return math_ops.range(array_ops.rank(x) - 1, -1, -1) 302 303 304@tf_export(v1=["strings.reduce_join", "reduce_join"]) 305@dispatch.add_dispatch_support 306@deprecation.deprecated_args(None, 307 "keep_dims is deprecated, use keepdims instead", 308 "keep_dims") 309@deprecation.deprecated_endpoints("reduce_join") 310def reduce_join(inputs, axis=None, # pylint: disable=missing-docstring 311 keep_dims=None, 312 separator="", 313 name=None, 314 reduction_indices=None, 315 keepdims=None): 316 keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims, 317 "keep_dims", keep_dims) 318 if keep_dims is None: 319 keep_dims = False 320 axis = deprecation.deprecated_argument_lookup("axis", axis, 321 "reduction_indices", 322 reduction_indices) 323 return reduce_join_v2( 324 inputs=inputs, 325 axis=axis, 326 keepdims=keepdims, 327 separator=separator, 328 name=name) 329 330 331@tf_export("strings.reduce_join", v1=[]) 332@dispatch.add_dispatch_support 333def reduce_join_v2( # pylint: disable=missing-docstring 334 inputs, 335 axis=None, 336 keepdims=False, 337 separator="", 338 name=None): 339 """Joins all strings into a single string, or joins along an axis. 340 341 This is the reduction operation for the elementwise `tf.strings.join` op. 342 343 >>> tf.strings.reduce_join([['abc','123'], 344 ... ['def','456']]).numpy() 345 b'abc123def456' 346 >>> tf.strings.reduce_join([['abc','123'], 347 ... ['def','456']], axis=-1).numpy() 348 array([b'abc123', b'def456'], dtype=object) 349 >>> tf.strings.reduce_join([['abc','123'], 350 ... ['def','456']], 351 ... axis=-1, 352 ... separator=" ").numpy() 353 array([b'abc 123', b'def 456'], dtype=object) 354 355 Args: 356 inputs: A `tf.string` tensor. 357 axis: Which axis to join along. The default behavior is to join all 358 elements, producing a scalar. 359 keepdims: If true, retains reduced dimensions with length 1. 360 separator: a string added between each string being joined. 361 name: A name for the operation (optional). 362 363 Returns: 364 A `tf.string` tensor. 365 """ 366 with ops.name_scope(None, "ReduceJoin", [inputs, axis]): 367 inputs_t = ops.convert_to_tensor(inputs) 368 axis = _reduce_join_reduction_dims(inputs_t, axis) 369 return gen_string_ops.reduce_join( 370 inputs=inputs_t, 371 reduction_indices=axis, 372 keep_dims=keepdims, 373 separator=separator, 374 name=name) 375 376reduce_join.__doc__ = reduce_join_v2.__doc__ 377 378 379# This wrapper provides backwards compatibility for code that predates the 380# unit argument and that passed 'name' as a positional argument. 381@tf_export(v1=["strings.length"]) 382@dispatch.add_dispatch_support 383def string_length(input, name=None, unit="BYTE"): 384 """Computes the length of each string given in the input tensor. 385 386 >>> strings = tf.constant(['Hello','TensorFlow', '']) 387 >>> tf.strings.length(strings).numpy() # default counts bytes 388 array([ 5, 10, 4], dtype=int32) 389 >>> tf.strings.length(strings, unit="UTF8_CHAR").numpy() 390 array([ 5, 10, 1], dtype=int32) 391 392 Args: 393 input: A `Tensor` of type `string`. The strings for which to compute the 394 length for each element. 395 name: A name for the operation (optional). 396 unit: An optional `string` from: `"BYTE", "UTF8_CHAR"`. Defaults to 397 `"BYTE"`. The unit that is counted to compute string length. One of: 398 `"BYTE"` (for the number of bytes in each string) or `"UTF8_CHAR"` (for 399 the number of UTF-8 encoded Unicode code points in each string). Results 400 are undefined if `unit=UTF8_CHAR` and the `input` strings do not contain 401 structurally valid UTF-8. 402 403 Returns: 404 A `Tensor` of type `int32`, containing the length of the input string in 405 the same element of the input tensor. 406 """ 407 return gen_string_ops.string_length(input, unit=unit, name=name) 408 409 410@tf_export("strings.length", v1=[]) 411@dispatch.add_dispatch_support 412def string_length_v2(input, unit="BYTE", name=None): 413 return gen_string_ops.string_length(input, unit=unit, name=name) 414 415 416string_length_v2.__doc__ = gen_string_ops.string_length.__doc__ 417 418 419@tf_export(v1=["substr"]) 420@dispatch.add_dispatch_support 421@deprecation.deprecated(None, "Use `tf.strings.substr` instead of `tf.substr`.") 422def substr_deprecated(input, pos, len, name=None, unit="BYTE"): 423 return substr(input, pos, len, name=name, unit=unit) 424 425substr_deprecated.__doc__ = gen_string_ops.substr.__doc__ 426 427 428@tf_export(v1=["strings.substr"]) 429@dispatch.add_dispatch_support 430def substr(input, pos, len, name=None, unit="BYTE"): 431 return gen_string_ops.substr(input, pos, len, unit=unit, name=name) 432 433substr.__doc__ = gen_string_ops.substr.__doc__ 434 435 436@tf_export("strings.substr", v1=[]) 437@dispatch.add_dispatch_support 438def substr_v2(input, pos, len, unit="BYTE", name=None): 439 return gen_string_ops.substr(input, pos, len, unit=unit, name=name) 440 441substr_v2.__doc__ = gen_string_ops.substr.__doc__ 442 443 444ops.NotDifferentiable("RegexReplace") 445ops.NotDifferentiable("StringToHashBucket") 446ops.NotDifferentiable("StringToHashBucketFast") 447ops.NotDifferentiable("StringToHashBucketStrong") 448ops.NotDifferentiable("ReduceJoin") 449ops.NotDifferentiable("StringJoin") 450ops.NotDifferentiable("StringSplit") 451ops.NotDifferentiable("AsString") 452ops.NotDifferentiable("EncodeBase64") 453ops.NotDifferentiable("DecodeBase64") 454 455 456@tf_export("strings.to_number", v1=[]) 457@dispatch.add_dispatch_support 458def string_to_number(input, out_type=dtypes.float32, name=None): 459 r"""Converts each string in the input Tensor to the specified numeric type. 460 461 (Note that int32 overflow results in an error while float overflow 462 results in a rounded value.) 463 464 Examples: 465 466 >>> tf.strings.to_number("1.55") 467 <tf.Tensor: shape=(), dtype=float32, numpy=1.55> 468 >>> tf.strings.to_number("3", tf.int32) 469 <tf.Tensor: shape=(), dtype=int32, numpy=3> 470 471 Args: 472 input: A `Tensor` of type `string`. 473 out_type: An optional `tf.DType` from: `tf.float32, tf.float64, tf.int32, 474 tf.int64`. Defaults to `tf.float32`. 475 The numeric type to interpret each string in `string_tensor` as. 476 name: A name for the operation (optional). 477 478 Returns: 479 A `Tensor` of type `out_type`. 480 """ 481 return gen_parsing_ops.string_to_number(input, out_type, name) 482 483 484@tf_export(v1=["strings.to_number", "string_to_number"]) 485@dispatch.add_dispatch_support 486def string_to_number_v1( 487 string_tensor=None, 488 out_type=dtypes.float32, 489 name=None, 490 input=None): 491 string_tensor = deprecation.deprecated_argument_lookup( 492 "input", input, "string_tensor", string_tensor) 493 return gen_parsing_ops.string_to_number(string_tensor, out_type, name) 494 495string_to_number_v1.__doc__ = gen_parsing_ops.string_to_number.__doc__ 496 497 498@tf_export("strings.to_hash_bucket", v1=[]) 499@dispatch.add_dispatch_support 500def string_to_hash_bucket(input, num_buckets, name=None): 501 # pylint: disable=line-too-long 502 r"""Converts each string in the input Tensor to its hash mod by a number of buckets. 503 504 The hash function is deterministic on the content of the string within the 505 process. 506 507 Note that the hash function may change from time to time. 508 This functionality will be deprecated and it's recommended to use 509 `tf.strings.to_hash_bucket_fast()` or `tf.strings.to_hash_bucket_strong()`. 510 511 Examples: 512 513 >>> tf.strings.to_hash_bucket(["Hello", "TensorFlow", "2.x"], 3) 514 <tf.Tensor: shape=(3,), dtype=int64, numpy=array([2, 0, 1])> 515 516 Args: 517 input: A `Tensor` of type `string`. 518 num_buckets: An `int` that is `>= 1`. The number of buckets. 519 name: A name for the operation (optional). 520 521 Returns: 522 A `Tensor` of type `int64`. 523 """ 524 # pylint: enable=line-too-long 525 return gen_string_ops.string_to_hash_bucket(input, num_buckets, name) 526 527 528@tf_export(v1=["strings.to_hash_bucket", "string_to_hash_bucket"]) 529@dispatch.add_dispatch_support 530def string_to_hash_bucket_v1( 531 string_tensor=None, 532 num_buckets=None, 533 name=None, 534 input=None): 535 string_tensor = deprecation.deprecated_argument_lookup( 536 "input", input, "string_tensor", string_tensor) 537 return gen_string_ops.string_to_hash_bucket(string_tensor, num_buckets, name) 538 539string_to_hash_bucket_v1.__doc__ = gen_string_ops.string_to_hash_bucket.__doc__ 540 541 542@tf_export("strings.join", v1=["strings.join", "string_join"]) 543@dispatch.add_dispatch_support 544@deprecation.deprecated_endpoints("string_join") 545@dispatch.add_dispatch_support 546def string_join(inputs, separator="", name=None): 547 """Perform element-wise concatenation of a list of string tensors. 548 549 Given a list of string tensors of same shape, performs element-wise 550 concatenation of the strings of the same index in all tensors. 551 552 553 >>> tf.strings.join(['abc','def']).numpy() 554 b'abcdef' 555 >>> tf.strings.join([['abc','123'], 556 ... ['def','456'], 557 ... ['ghi','789']]).numpy() 558 array([b'abcdefghi', b'123456789'], dtype=object) 559 >>> tf.strings.join([['abc','123'], 560 ... ['def','456']], 561 ... separator=" ").numpy() 562 array([b'abc def', b'123 456'], dtype=object) 563 564 The reduction version of this elementwise operation is 565 `tf.strings.reduce_join` 566 567 Args: 568 inputs: A list of `tf.Tensor` objects of same size and `tf.string` dtype. 569 separator: A string added between each string being joined. 570 name: A name for the operation (optional). 571 572 Returns: 573 A `tf.string` tensor. 574 """ 575 return gen_string_ops.string_join(inputs, separator=separator, name=name) 576