1# -*- coding: utf-8 -*- 2# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16 17"""Operations for working with string Tensors.""" 18 19from __future__ import absolute_import 20from __future__ import division 21from __future__ import print_function 22 23import numpy as np 24 25from tensorflow.python.framework import constant_op 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import sparse_tensor 29from tensorflow.python.framework import tensor_util 30from tensorflow.python.ops import array_ops 31from tensorflow.python.ops import gen_parsing_ops 32from tensorflow.python.ops import gen_string_ops 33from tensorflow.python.ops import math_ops 34 35# go/tf-wildcard-import 36# pylint: disable=wildcard-import 37# pylint: disable=g-bad-import-order 38from tensorflow.python.ops.gen_string_ops import * 39from tensorflow.python.util import compat as util_compat 40from tensorflow.python.util import deprecation 41from tensorflow.python.util import dispatch 42from tensorflow.python.util.tf_export import tf_export 43# pylint: enable=g-bad-import-order 44# pylint: enable=wildcard-import 45 46 47# pylint: disable=redefined-builtin 48@tf_export("strings.regex_full_match") 49@dispatch.add_dispatch_support 50def regex_full_match(input, pattern, name=None): 51 r"""Match elements of `input` with regex `pattern`. 52 53 Args: 54 input: string `Tensor`, the source strings to process. 55 pattern: string or scalar string `Tensor`, regular expression to use, 56 see more details at https://github.com/google/re2/wiki/Syntax 57 name: Name of the op. 58 59 Returns: 60 bool `Tensor` of the same shape as `input` with match results. 61 """ 62 if isinstance(pattern, util_compat.bytes_or_text_types): 63 # When `pattern` is static through the life of the op we can 64 # use a version which performs the expensive regex compilation once at 65 # creation time. 66 return gen_string_ops.static_regex_full_match( 67 input=input, pattern=pattern, name=name) 68 return gen_string_ops.regex_full_match( 69 input=input, pattern=pattern, name=name) 70 71regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__ 72 73 74@tf_export( 75 "strings.regex_replace", v1=["strings.regex_replace", "regex_replace"]) 76@deprecation.deprecated_endpoints("regex_replace") 77@dispatch.add_dispatch_support 78def regex_replace(input, pattern, rewrite, replace_global=True, name=None): 79 r"""Replace elements of `input` matching regex `pattern` with `rewrite`. 80 81 >>> tf.strings.regex_replace("Text with tags.<br /><b>contains html</b>", 82 ... "<[^>]+>", " ") 83 <tf.Tensor: shape=(), dtype=string, numpy=b'Text with tags. contains html '> 84 85 Args: 86 input: string `Tensor`, the source strings to process. 87 pattern: string or scalar string `Tensor`, regular expression to use, 88 see more details at https://github.com/google/re2/wiki/Syntax 89 rewrite: string or scalar string `Tensor`, value to use in match 90 replacement, supports backslash-escaped digits (\1 to \9) can be to insert 91 text matching corresponding parenthesized group. 92 replace_global: `bool`, if `True` replace all non-overlapping matches, 93 else replace only the first match. 94 name: A name for the operation (optional). 95 96 Returns: 97 string `Tensor` of the same shape as `input` with specified replacements. 98 """ 99 if (isinstance(pattern, util_compat.bytes_or_text_types) and 100 isinstance(rewrite, util_compat.bytes_or_text_types)): 101 # When `pattern` and `rewrite` are static through the life of the op we can 102 # use a version which performs the expensive regex compilation once at 103 # creation time. 104 return gen_string_ops.static_regex_replace( 105 input=input, pattern=pattern, 106 rewrite=rewrite, replace_global=replace_global, 107 name=name) 108 return gen_string_ops.regex_replace( 109 input=input, pattern=pattern, 110 rewrite=rewrite, replace_global=replace_global, 111 name=name) 112 113 114@tf_export("strings.format") 115def string_format(template, inputs, placeholder="{}", summarize=3, name=None): 116 r"""Formats a string template using a list of tensors. 117 118 Formats a string template using a list of tensors, abbreviating tensors by 119 only printing the first and last `summarize` elements of each dimension 120 (recursively). If formatting only one tensor into a template, the tensor does 121 not have to be wrapped in a list. 122 123 Example: 124 Formatting a single-tensor template: 125 ```python 126 sess = tf.compat.v1.Session() 127 with sess.as_default(): 128 tensor = tf.range(10) 129 formatted = tf.strings.format("tensor: {}, suffix", tensor) 130 out = sess.run(formatted) 131 expected = "tensor: [0 1 2 ... 7 8 9], suffix" 132 133 assert(out.decode() == expected) 134 ``` 135 136 Formatting a multi-tensor template: 137 ```python 138 sess = tf.compat.v1.Session() 139 with sess.as_default(): 140 tensor_one = tf.reshape(tf.range(100), [10, 10]) 141 tensor_two = tf.range(10) 142 formatted = tf.strings.format("first: {}, second: {}, suffix", 143 (tensor_one, tensor_two)) 144 145 out = sess.run(formatted) 146 expected = ("first: [[0 1 2 ... 7 8 9]\n" 147 " [10 11 12 ... 17 18 19]\n" 148 " [20 21 22 ... 27 28 29]\n" 149 " ...\n" 150 " [70 71 72 ... 77 78 79]\n" 151 " [80 81 82 ... 87 88 89]\n" 152 " [90 91 92 ... 97 98 99]], second: [0 1 2 ... 7 8 9], suffix") 153 154 assert(out.decode() == expected) 155 ``` 156 157 Args: 158 template: A string template to format tensor values into. 159 inputs: A list of `Tensor` objects, or a single Tensor. 160 The list of tensors to format into the template string. If a solitary 161 tensor is passed in, the input tensor will automatically be wrapped as a 162 list. 163 placeholder: An optional `string`. Defaults to `{}`. 164 At each placeholder occurring in the template, a subsequent tensor 165 will be inserted. 166 summarize: An optional `int`. Defaults to `3`. 167 When formatting the tensors, show the first and last `summarize` 168 entries of each tensor dimension (recursively). If set to -1, all 169 elements of the tensor will be shown. 170 name: A name for the operation (optional). 171 172 Returns: 173 A scalar `Tensor` of type `string`. 174 175 Raises: 176 ValueError: if the number of placeholders does not match the number of 177 inputs. 178 """ 179 # If there is only one tensor to format, we will automatically wrap it in a 180 # list to simplify the user experience 181 if tensor_util.is_tensor(inputs): 182 inputs = [inputs] 183 if template.count(placeholder) != len(inputs): 184 raise ValueError("%s placeholder(s) in template does not match %s tensor(s)" 185 " provided as input" % (template.count(placeholder), 186 len(inputs))) 187 188 return gen_string_ops.string_format(inputs, 189 template=template, 190 placeholder=placeholder, 191 summarize=summarize, 192 name=name) 193 194 195# Note: tf.strings.split is exported in ragged/ragged_string_ops.py, which 196# defines a wrapper for this function. 197def string_split(source, sep=None, skip_empty=True, delimiter=None): # pylint: disable=invalid-name 198 """Split elements of `source` based on `delimiter` into a `SparseTensor`. 199 200 Let N be the size of source (typically N will be the batch size). Split each 201 element of `source` based on `delimiter` and return a `SparseTensor` 202 containing the split tokens. Empty tokens are ignored. 203 204 If `sep` is an empty string, each element of the `source` is split 205 into individual strings, each containing one byte. (This includes splitting 206 multibyte sequences of UTF-8.) If delimiter contains multiple bytes, it is 207 treated as a set of delimiters with each considered a potential split point. 208 209 For example: 210 N = 2, source[0] is 'hello world' and source[1] is 'a b c', then the output 211 will be 212 213 st.indices = [0, 0; 214 0, 1; 215 1, 0; 216 1, 1; 217 1, 2] 218 st.shape = [2, 3] 219 st.values = ['hello', 'world', 'a', 'b', 'c'] 220 221 Args: 222 source: `1-D` string `Tensor`, the strings to split. 223 sep: `0-D` string `Tensor`, the delimiter character, the string should 224 be length 0 or 1. Default is ' '. 225 skip_empty: A `bool`. If `True`, skip the empty strings from the result. 226 delimiter: deprecated alias for `sep`. 227 228 Raises: 229 ValueError: If delimiter is not a string. 230 231 Returns: 232 A `SparseTensor` of rank `2`, the strings split according to the delimiter. 233 The first column of the indices corresponds to the row in `source` and the 234 second column corresponds to the index of the split component in this row. 235 """ 236 delimiter = deprecation.deprecated_argument_lookup( 237 "sep", sep, "delimiter", delimiter) 238 239 if delimiter is None: 240 delimiter = " " 241 delimiter = ops.convert_to_tensor(delimiter, dtype=dtypes.string) 242 source = ops.convert_to_tensor(source, dtype=dtypes.string) 243 244 indices, values, shape = gen_string_ops.string_split( 245 source, delimiter=delimiter, skip_empty=skip_empty) 246 indices.set_shape([None, 2]) 247 values.set_shape([None]) 248 shape.set_shape([2]) 249 return sparse_tensor.SparseTensor(indices, values, shape) 250 251 252# Note: tf.strings.split is exported in ragged/ragged_string_ops.py, which 253# defines a wrapper for this function. 254def string_split_v2(source, sep=None, maxsplit=-1): 255 """Split elements of `source` based on `sep` into a `SparseTensor`. 256 257 Let N be the size of source (typically N will be the batch size). Split each 258 element of `source` based on `sep` and return a `SparseTensor` 259 containing the split tokens. Empty tokens are ignored. 260 261 For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c', 262 then the output will be 263 264 st.indices = [0, 0; 265 0, 1; 266 1, 0; 267 1, 1; 268 1, 2] 269 st.shape = [2, 3] 270 st.values = ['hello', 'world', 'a', 'b', 'c'] 271 272 If `sep` is given, consecutive delimiters are not grouped together and are 273 deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and 274 sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty 275 string, consecutive whitespace are regarded as a single separator, and the 276 result will contain no empty strings at the start or end if the string has 277 leading or trailing whitespace. 278 279 Note that the above mentioned behavior matches python's str.split. 280 281 Args: 282 source: `1-D` string `Tensor`, the strings to split. 283 sep: `0-D` string `Tensor`, the delimiter character. 284 maxsplit: An `int`. If `maxsplit > 0`, limit of the split of the result. 285 286 Raises: 287 ValueError: If sep is not a string. 288 289 Returns: 290 A `SparseTensor` of rank `2`, the strings split according to the delimiter. 291 The first column of the indices corresponds to the row in `source` and the 292 second column corresponds to the index of the split component in this row. 293 """ 294 if sep is None: 295 sep = "" 296 sep = ops.convert_to_tensor(sep, dtype=dtypes.string) 297 source = ops.convert_to_tensor(source, dtype=dtypes.string) 298 299 indices, values, shape = gen_string_ops.string_split_v2( 300 source, sep=sep, maxsplit=maxsplit) 301 indices.set_shape([None, 2]) 302 values.set_shape([None]) 303 shape.set_shape([2]) 304 return sparse_tensor.SparseTensor(indices, values, shape) 305 306 307def _reduce_join_reduction_dims(x, axis): 308 """Returns range(rank(x) - 1, 0, -1) if axis is None; or axis otherwise.""" 309 if axis is not None: 310 return axis 311 else: 312 # Fast path: avoid creating Rank and Range ops if ndims is known. 313 if x.get_shape().ndims is not None: 314 return constant_op.constant( 315 np.arange(x.get_shape().ndims - 1, -1, -1), dtype=dtypes.int32) 316 317 # Otherwise, we rely on Range and Rank to do the right thing at run-time. 318 return math_ops.range(array_ops.rank(x) - 1, -1, -1) 319 320 321@tf_export(v1=["strings.reduce_join", "reduce_join"]) 322@deprecation.deprecated_args(None, 323 "keep_dims is deprecated, use keepdims instead", 324 "keep_dims") 325@deprecation.deprecated_endpoints("reduce_join") 326def reduce_join(inputs, axis=None, # pylint: disable=missing-docstring 327 keep_dims=None, 328 separator="", 329 name=None, 330 reduction_indices=None, 331 keepdims=None): 332 keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims, 333 "keep_dims", keep_dims) 334 if keep_dims is None: 335 keep_dims = False 336 axis = deprecation.deprecated_argument_lookup("axis", axis, 337 "reduction_indices", 338 reduction_indices) 339 return reduce_join_v2( 340 inputs=inputs, 341 axis=axis, 342 keepdims=keepdims, 343 separator=separator, 344 name=name) 345 346 347@tf_export("strings.reduce_join", v1=[]) 348@dispatch.add_dispatch_support 349def reduce_join_v2( # pylint: disable=missing-docstring 350 inputs, 351 axis=None, 352 keepdims=False, 353 separator="", 354 name=None): 355 """Joins all strings into a single string, or joins along an axis. 356 357 >>> tf.strings.reduce_join([['abc','123'], 358 ... ['def','456']]).numpy() 359 b'abc123def456' 360 >>> tf.strings.reduce_join([['abc','123'], 361 ... ['def','456']], axis=-1).numpy() 362 array([b'abc123', b'def456'], dtype=object) 363 >>> tf.strings.reduce_join([['abc','123'], 364 ... ['def','456']], 365 ... axis=-1, 366 ... separator=" ").numpy() 367 array([b'abc 123', b'def 456'], dtype=object) 368 369 Args: 370 inputs: A `tf.string` tensor. 371 axis: Which axis to join along. The default behavior is to join all 372 elements, producing a scalar. 373 keepdims: If true, retains reduced dimensions with length 1. 374 separator: a string added between each string being joined. 375 name: A name for the operation (optional). 376 377 Returns: 378 A `tf.string` tensor. 379 """ 380 with ops.name_scope(None, "ReduceJoin", [inputs, axis]): 381 inputs_t = ops.convert_to_tensor(inputs) 382 axis = _reduce_join_reduction_dims(inputs_t, axis) 383 return gen_string_ops.reduce_join( 384 inputs=inputs_t, 385 reduction_indices=axis, 386 keep_dims=keepdims, 387 separator=separator, 388 name=name) 389 390reduce_join.__doc__ = reduce_join_v2.__doc__ 391 392 393# This wrapper provides backwards compatibility for code that predates the 394# unit argument and that passed 'name' as a positional argument. 395@tf_export(v1=["strings.length"]) 396@dispatch.add_dispatch_support 397def string_length(input, name=None, unit="BYTE"): 398 """Computes the length of each string given in the input tensor. 399 400 >>> strings = tf.constant(['Hello','TensorFlow', '']) 401 >>> tf.strings.length(strings).numpy() # default counts bytes 402 array([ 5, 10, 4], dtype=int32) 403 >>> tf.strings.length(strings, unit="UTF8_CHAR").numpy() 404 array([ 5, 10, 1], dtype=int32) 405 406 Args: 407 input: A `Tensor` of type `string`. The strings for which to compute the 408 length for each element. 409 name: A name for the operation (optional). 410 unit: An optional `string` from: `"BYTE", "UTF8_CHAR"`. Defaults to 411 `"BYTE"`. The unit that is counted to compute string length. One of: 412 `"BYTE"` (for the number of bytes in each string) or `"UTF8_CHAR"` (for 413 the number of UTF-8 encoded Unicode code points in each string). Results 414 are undefined if `unit=UTF8_CHAR` and the `input` strings do not contain 415 structurally valid UTF-8. 416 417 Returns: 418 A `Tensor` of type `int32`, containing the length of the input string in 419 the same element of the input tensor. 420 """ 421 return gen_string_ops.string_length(input, unit=unit, name=name) 422 423 424@tf_export("strings.length", v1=[]) 425@dispatch.add_dispatch_support 426def string_length_v2(input, unit="BYTE", name=None): 427 return gen_string_ops.string_length(input, unit=unit, name=name) 428 429 430string_length_v2.__doc__ = gen_string_ops.string_length.__doc__ 431 432 433@tf_export(v1=["substr"]) 434@deprecation.deprecated(None, "Use `tf.strings.substr` instead of `tf.substr`.") 435def substr_deprecated(input, pos, len, name=None, unit="BYTE"): 436 return substr(input, pos, len, name=name, unit=unit) 437 438substr_deprecated.__doc__ = gen_string_ops.substr.__doc__ 439 440 441@tf_export(v1=["strings.substr"]) 442@dispatch.add_dispatch_support 443def substr(input, pos, len, name=None, unit="BYTE"): 444 return gen_string_ops.substr(input, pos, len, unit=unit, name=name) 445 446substr.__doc__ = gen_string_ops.substr.__doc__ 447 448 449@tf_export("strings.substr", v1=[]) 450@dispatch.add_dispatch_support 451def substr_v2(input, pos, len, unit="BYTE", name=None): 452 return gen_string_ops.substr(input, pos, len, unit=unit, name=name) 453 454substr_v2.__doc__ = gen_string_ops.substr.__doc__ 455 456 457ops.NotDifferentiable("RegexReplace") 458ops.NotDifferentiable("StringToHashBucket") 459ops.NotDifferentiable("StringToHashBucketFast") 460ops.NotDifferentiable("StringToHashBucketStrong") 461ops.NotDifferentiable("ReduceJoin") 462ops.NotDifferentiable("StringJoin") 463ops.NotDifferentiable("StringSplit") 464ops.NotDifferentiable("AsString") 465ops.NotDifferentiable("EncodeBase64") 466ops.NotDifferentiable("DecodeBase64") 467 468 469@tf_export("strings.to_number", v1=[]) 470@dispatch.add_dispatch_support 471def string_to_number(input, out_type=dtypes.float32, name=None): 472 r"""Converts each string in the input Tensor to the specified numeric type. 473 474 (Note that int32 overflow results in an error while float overflow 475 results in a rounded value.) 476 477 Args: 478 input: A `Tensor` of type `string`. 479 out_type: An optional `tf.DType` from: `tf.float32, tf.float64, tf.int32, 480 tf.int64`. Defaults to `tf.float32`. 481 The numeric type to interpret each string in `string_tensor` as. 482 name: A name for the operation (optional). 483 484 Returns: 485 A `Tensor` of type `out_type`. 486 """ 487 return gen_parsing_ops.string_to_number(input, out_type, name) 488 489 490@tf_export(v1=["strings.to_number", "string_to_number"]) 491def string_to_number_v1( 492 string_tensor=None, 493 out_type=dtypes.float32, 494 name=None, 495 input=None): 496 string_tensor = deprecation.deprecated_argument_lookup( 497 "input", input, "string_tensor", string_tensor) 498 return gen_parsing_ops.string_to_number(string_tensor, out_type, name) 499 500string_to_number_v1.__doc__ = gen_parsing_ops.string_to_number.__doc__ 501 502 503@tf_export("strings.to_hash_bucket", v1=[]) 504@dispatch.add_dispatch_support 505def string_to_hash_bucket(input, num_buckets, name=None): 506 # pylint: disable=line-too-long 507 r"""Converts each string in the input Tensor to its hash mod by a number of buckets. 508 509 The hash function is deterministic on the content of the string within the 510 process. 511 512 Note that the hash function may change from time to time. 513 This functionality will be deprecated and it's recommended to use 514 `tf.strings.to_hash_bucket_fast()` or `tf.strings.to_hash_bucket_strong()`. 515 516 Args: 517 input: A `Tensor` of type `string`. 518 num_buckets: An `int` that is `>= 1`. The number of buckets. 519 name: A name for the operation (optional). 520 521 Returns: 522 A `Tensor` of type `int64`. 523 """ 524 # pylint: enable=line-too-long 525 return gen_string_ops.string_to_hash_bucket(input, num_buckets, name) 526 527 528@tf_export(v1=["strings.to_hash_bucket", "string_to_hash_bucket"]) 529def string_to_hash_bucket_v1( 530 string_tensor=None, 531 num_buckets=None, 532 name=None, 533 input=None): 534 string_tensor = deprecation.deprecated_argument_lookup( 535 "input", input, "string_tensor", string_tensor) 536 return gen_string_ops.string_to_hash_bucket(string_tensor, num_buckets, name) 537 538string_to_hash_bucket_v1.__doc__ = gen_string_ops.string_to_hash_bucket.__doc__ 539 540 541@tf_export("strings.join", v1=["strings.join", "string_join"]) 542@deprecation.deprecated_endpoints("string_join") 543@dispatch.add_dispatch_support 544def string_join(inputs, separator="", name=None): 545 """Perform element-wise concatenation of a list of string tensors. 546 547 Given a list of string tensors of same shape, performs element-wise 548 concatenation of the strings of the same index in all tensors. 549 550 551 >>> tf.strings.join(['abc','def']).numpy() 552 b'abcdef' 553 >>> tf.strings.join([['abc','123'], 554 ... ['def','456'], 555 ... ['ghi','789']]).numpy() 556 array([b'abcdefghi', b'123456789'], dtype=object) 557 >>> tf.strings.join([['abc','123'], 558 ... ['def','456']], 559 ... separator=" ").numpy() 560 array([b'abc def', b'123 456'], dtype=object) 561 562 Args: 563 inputs: A list of `tf.Tensor` objects of same size and `tf.string` dtype. 564 separator: A string added between each string being joined. 565 name: A name for the operation (optional). 566 567 Returns: 568 A `tf.string` tensor. 569 """ 570 return gen_string_ops.string_join(inputs, separator=separator, name=name) 571