1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Operations for working with string Tensors.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import numpy as np 23 24from tensorflow.python.compat import compat 25from tensorflow.python.framework import constant_op 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import sparse_tensor 29from tensorflow.python.framework import tensor_util 30from tensorflow.python.ops import array_ops 31from tensorflow.python.ops import gen_parsing_ops 32from tensorflow.python.ops import gen_string_ops 33from tensorflow.python.ops import math_ops 34 35# go/tf-wildcard-import 36# pylint: disable=wildcard-import 37# pylint: disable=g-bad-import-order 38from tensorflow.python.ops.gen_string_ops import * 39from tensorflow.python.util import compat as util_compat 40from tensorflow.python.util import deprecation 41from tensorflow.python.util import dispatch 42from tensorflow.python.util.tf_export import tf_export 43# pylint: enable=g-bad-import-order 44# pylint: enable=wildcard-import 45 46 47# pylint: disable=redefined-builtin 48@tf_export("strings.regex_full_match") 49@dispatch.add_dispatch_support 50def regex_full_match(input, 
pattern, name=None): 51 r"""Match elements of `input` with regex `pattern`. 52 53 Args: 54 input: string `Tensor`, the source strings to process. 55 pattern: string or scalar string `Tensor`, regular expression to use, 56 see more details at https://github.com/google/re2/wiki/Syntax 57 name: Name of the op. 58 59 Returns: 60 bool `Tensor` of the same shape as `input` with match results. 61 """ 62 # TODO(b/112455102): Remove compat.forward_compatible once past the horizon. 63 if not compat.forward_compatible(2018, 11, 10): 64 return gen_string_ops.regex_full_match( 65 input=input, pattern=pattern, name=name) 66 if isinstance(pattern, util_compat.bytes_or_text_types): 67 # When `pattern` is static through the life of the op we can 68 # use a version which performs the expensive regex compilation once at 69 # creation time. 70 return gen_string_ops.static_regex_full_match( 71 input=input, pattern=pattern, name=name) 72 return gen_string_ops.regex_full_match( 73 input=input, pattern=pattern, name=name) 74 75regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__ 76 77 78@tf_export( 79 "strings.regex_replace", v1=["strings.regex_replace", "regex_replace"]) 80@deprecation.deprecated_endpoints("regex_replace") 81@dispatch.add_dispatch_support 82def regex_replace(input, pattern, rewrite, replace_global=True, name=None): 83 r"""Replace elements of `input` matching regex `pattern` with `rewrite`. 84 85 Args: 86 input: string `Tensor`, the source strings to process. 87 pattern: string or scalar string `Tensor`, regular expression to use, 88 see more details at https://github.com/google/re2/wiki/Syntax 89 rewrite: string or scalar string `Tensor`, value to use in match 90 replacement, supports backslash-escaped digits (\1 to \9) can be to insert 91 text matching corresponding parenthesized group. 92 replace_global: `bool`, if `True` replace all non-overlapping matches, 93 else replace only the first match. 94 name: A name for the operation (optional). 
95 96 Returns: 97 string `Tensor` of the same shape as `input` with specified replacements. 98 """ 99 if (isinstance(pattern, util_compat.bytes_or_text_types) and 100 isinstance(rewrite, util_compat.bytes_or_text_types)): 101 # When `pattern` and `rewrite` are static through the life of the op we can 102 # use a version which performs the expensive regex compilation once at 103 # creation time. 104 return gen_string_ops.static_regex_replace( 105 input=input, pattern=pattern, 106 rewrite=rewrite, replace_global=replace_global, 107 name=name) 108 return gen_string_ops.regex_replace( 109 input=input, pattern=pattern, 110 rewrite=rewrite, replace_global=replace_global, 111 name=name) 112 113 114@tf_export("strings.format") 115def string_format(template, inputs, placeholder="{}", summarize=3, name=None): 116 r"""Formats a string template using a list of tensors. 117 118 Formats a string template using a list of tensors, abbreviating tensors by 119 only printing the first and last `summarize` elements of each dimension 120 (recursively). If formatting only one tensor into a template, the tensor does 121 not have to be wrapped in a list. 122 123 Example: 124 Formatting a single-tensor template: 125 ```python 126 sess = tf.Session() 127 with sess.as_default(): 128 tensor = tf.range(10) 129 formatted = tf.strings.format("tensor: {}, suffix", tensor) 130 out = sess.run(formatted) 131 expected = "tensor: [0 1 2 ... 7 8 9], suffix" 132 133 assert(out.decode() == expected) 134 ``` 135 136 Formatting a multi-tensor template: 137 ```python 138 sess = tf.Session() 139 with sess.as_default(): 140 tensor_one = tf.reshape(tf.range(100), [10, 10]) 141 tensor_two = tf.range(10) 142 formatted = tf.strings.format("first: {}, second: {}, suffix", 143 (tensor_one, tensor_two)) 144 145 out = sess.run(formatted) 146 expected = ("first: [[0 1 2 ... 7 8 9]\n" 147 " [10 11 12 ... 17 18 19]\n" 148 " [20 21 22 ... 27 28 29]\n" 149 " ...\n" 150 " [70 71 72 ... 77 78 79]\n" 151 " [80 81 82 ... 
87 88 89]\n" 152 " [90 91 92 ... 97 98 99]], second: [0 1 2 ... 7 8 9], suffix") 153 154 assert(out.decode() == expected) 155 ``` 156 157 Args: 158 template: A string template to format tensor values into. 159 inputs: A list of `Tensor` objects, or a single Tensor. 160 The list of tensors to format into the template string. If a solitary 161 tensor is passed in, the input tensor will automatically be wrapped as a 162 list. 163 placeholder: An optional `string`. Defaults to `{}`. 164 At each placeholder occurring in the template, a subsequent tensor 165 will be inserted. 166 summarize: An optional `int`. Defaults to `3`. 167 When formatting the tensors, show the first and last `summarize` 168 entries of each tensor dimension (recursively). If set to -1, all 169 elements of the tensor will be shown. 170 name: A name for the operation (optional). 171 172 Returns: 173 A scalar `Tensor` of type `string`. 174 175 Raises: 176 ValueError: if the number of placeholders does not match the number of 177 inputs. 178 """ 179 # If there is only one tensor to format, we will automatically wrap it in a 180 # list to simplify the user experience 181 if tensor_util.is_tensor(inputs): 182 inputs = [inputs] 183 if template.count(placeholder) != len(inputs): 184 raise ValueError("%s placeholder(s) in template does not match %s tensor(s)" 185 " provided as input" % (template.count(placeholder), 186 len(inputs))) 187 188 return gen_string_ops.string_format(inputs, 189 template=template, 190 placeholder=placeholder, 191 summarize=summarize, 192 name=name) 193 194 195@tf_export(v1=["string_split"]) 196@deprecation.deprecated_args(None, 197 "delimiter is deprecated, please use sep instead.", 198 "delimiter") 199def string_split(source, sep=None, skip_empty=True, delimiter=None): # pylint: disable=invalid-name 200 """Split elements of `source` based on `delimiter` into a `SparseTensor`. 201 202 Let N be the size of source (typically N will be the batch size). 
Split each 203 element of `source` based on `delimiter` and return a `SparseTensor` 204 containing the split tokens. Empty tokens are ignored. 205 206 If `sep` is an empty string, each element of the `source` is split 207 into individual strings, each containing one byte. (This includes splitting 208 multibyte sequences of UTF-8.) If delimiter contains multiple bytes, it is 209 treated as a set of delimiters with each considered a potential split point. 210 211 For example: 212 N = 2, source[0] is 'hello world' and source[1] is 'a b c', then the output 213 will be 214 215 st.indices = [0, 0; 216 0, 1; 217 1, 0; 218 1, 1; 219 1, 2] 220 st.shape = [2, 3] 221 st.values = ['hello', 'world', 'a', 'b', 'c'] 222 223 Args: 224 source: `1-D` string `Tensor`, the strings to split. 225 sep: `0-D` string `Tensor`, the delimiter character, the string should 226 be length 0 or 1. Default is ' '. 227 skip_empty: A `bool`. If `True`, skip the empty strings from the result. 228 delimiter: deprecated alias for `sep`. 229 230 Raises: 231 ValueError: If delimiter is not a string. 232 233 Returns: 234 A `SparseTensor` of rank `2`, the strings split according to the delimiter. 235 The first column of the indices corresponds to the row in `source` and the 236 second column corresponds to the index of the split component in this row. 
237 """ 238 delimiter = deprecation.deprecated_argument_lookup( 239 "sep", sep, "delimiter", delimiter) 240 241 if delimiter is None: 242 delimiter = " " 243 delimiter = ops.convert_to_tensor(delimiter, dtype=dtypes.string) 244 source = ops.convert_to_tensor(source, dtype=dtypes.string) 245 246 indices, values, shape = gen_string_ops.string_split( 247 source, delimiter=delimiter, skip_empty=skip_empty) 248 indices.set_shape([None, 2]) 249 values.set_shape([None]) 250 shape.set_shape([2]) 251 return sparse_tensor.SparseTensor(indices, values, shape) 252 253 254@tf_export("strings.split") 255def string_split_v2(source, sep=None, maxsplit=-1): 256 """Split elements of `source` based on `sep` into a `SparseTensor`. 257 258 Let N be the size of source (typically N will be the batch size). Split each 259 element of `source` based on `sep` and return a `SparseTensor` 260 containing the split tokens. Empty tokens are ignored. 261 262 For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c', 263 then the output will be 264 265 st.indices = [0, 0; 266 0, 1; 267 1, 0; 268 1, 1; 269 1, 2] 270 st.shape = [2, 3] 271 st.values = ['hello', 'world', 'a', 'b', 'c'] 272 273 If `sep` is given, consecutive delimiters are not grouped together and are 274 deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and 275 sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty 276 string, consecutive whitespace are regarded as a single separator, and the 277 result will contain no empty strings at the start or end if the string has 278 leading or trailing whitespace. 279 280 Note that the above mentioned behavior matches python's str.split. 281 282 Args: 283 source: `1-D` string `Tensor`, the strings to split. 284 sep: `0-D` string `Tensor`, the delimiter character. 285 maxsplit: An `int`. If `maxsplit > 0`, limit of the split of the result. 286 287 Raises: 288 ValueError: If sep is not a string. 
289 290 Returns: 291 A `SparseTensor` of rank `2`, the strings split according to the delimiter. 292 The first column of the indices corresponds to the row in `source` and the 293 second column corresponds to the index of the split component in this row. 294 """ 295 if sep is None: 296 sep = "" 297 sep = ops.convert_to_tensor(sep, dtype=dtypes.string) 298 source = ops.convert_to_tensor(source, dtype=dtypes.string) 299 300 indices, values, shape = gen_string_ops.string_split_v2( 301 source, sep=sep, maxsplit=maxsplit) 302 indices.set_shape([None, 2]) 303 values.set_shape([None]) 304 shape.set_shape([2]) 305 return sparse_tensor.SparseTensor(indices, values, shape) 306 307 308def _reduce_join_reduction_dims(x, axis, reduction_indices): 309 """Returns range(rank(x) - 1, 0, -1) if reduction_indices is None.""" 310 # TODO(aselle): Remove this after deprecation 311 if reduction_indices is not None: 312 if axis is not None: 313 raise ValueError("Can't specify both 'axis' and 'reduction_indices'.") 314 axis = reduction_indices 315 if axis is not None: 316 return axis 317 else: 318 # Fast path: avoid creating Rank and Range ops if ndims is known. 319 if x.get_shape().ndims is not None: 320 return constant_op.constant( 321 np.arange(x.get_shape().ndims - 1, -1, -1), dtype=dtypes.int32) 322 323 # Otherwise, we rely on Range and Rank to do the right thing at run-time. 
324 return math_ops.range(array_ops.rank(x) - 1, -1, -1) 325 326 327@tf_export(v1=["strings.reduce_join", "reduce_join"]) 328@deprecation.deprecated_endpoints("reduce_join") 329def reduce_join(inputs, axis=None, # pylint: disable=missing-docstring 330 keep_dims=False, 331 separator="", 332 name=None, 333 reduction_indices=None, 334 keepdims=None): 335 keep_dims = deprecation.deprecated_argument_lookup( 336 "keepdims", keepdims, "keep_dims", keep_dims) 337 inputs_t = ops.convert_to_tensor(inputs) 338 reduction_indices = _reduce_join_reduction_dims( 339 inputs_t, axis, reduction_indices) 340 return gen_string_ops.reduce_join( 341 inputs=inputs_t, 342 reduction_indices=reduction_indices, 343 keep_dims=keep_dims, 344 separator=separator, 345 name=name) 346 347 348@tf_export("strings.reduce_join", v1=[]) 349def reduce_join_v2( # pylint: disable=missing-docstring 350 inputs, 351 axis=None, 352 keepdims=False, 353 separator="", 354 name=None): 355 return reduce_join( 356 inputs, axis, keep_dims=keepdims, separator=separator, name=name) 357 358 359reduce_join.__doc__ = deprecation.rewrite_argument_docstring( 360 gen_string_ops.reduce_join.__doc__, "reduction_indices", "axis") 361reduce_join.__doc__ = reduce_join.__doc__.replace("tf.reduce_join(", 362 "tf.strings.reduce_join(") 363 364 365# This wrapper provides backwards compatibility for code that predates the 366# unit argument and that passed 'name' as a positional argument. 
@tf_export(v1=["strings.length"])
@dispatch.add_dispatch_support
def string_length(input, name=None, unit="BYTE"):
  # V1 signature keeps `name` before `unit` for callers that passed `name`
  # positionally before `unit` existed; docstring attached below.
  return gen_string_ops.string_length(input, unit=unit, name=name)


@tf_export("strings.length", v1=[])
@dispatch.add_dispatch_support
def string_length_v2(input, unit="BYTE", name=None):
  # V2 spelling with the conventional (input, unit, name) argument order.
  return string_length(input, name, unit)


string_length.__doc__ = gen_string_ops.string_length.__doc__


@tf_export(v1=["substr"])
@deprecation.deprecated(None, "Use `tf.strings.substr` instead of `tf.substr`.")
def substr_deprecated(input, pos, len, name=None, unit="BYTE"):
  # Deprecated `tf.substr` endpoint; forwards to `substr` below.
  return substr(input, pos, len, name=name, unit=unit)

substr_deprecated.__doc__ = gen_string_ops.substr.__doc__


@tf_export(v1=["strings.substr"])
@dispatch.add_dispatch_support
def substr(input, pos, len, name=None, unit="BYTE"):
  return gen_string_ops.substr(input, pos, len, unit=unit, name=name)

substr.__doc__ = gen_string_ops.substr.__doc__


@tf_export("strings.substr", v1=[])
@dispatch.add_dispatch_support
def substr_v2(input, pos, len, unit="BYTE", name=None):
  # V2 spelling with `unit` before `name`.
  return gen_string_ops.substr(input, pos, len, unit=unit, name=name)

substr_v2.__doc__ = gen_string_ops.substr.__doc__


ops.NotDifferentiable("RegexReplace")
ops.NotDifferentiable("StringToHashBucket")
ops.NotDifferentiable("StringToHashBucketFast")
ops.NotDifferentiable("StringToHashBucketStrong")
ops.NotDifferentiable("ReduceJoin")
ops.NotDifferentiable("StringJoin")
ops.NotDifferentiable("StringSplit")
ops.NotDifferentiable("AsString")
ops.NotDifferentiable("EncodeBase64")
ops.NotDifferentiable("DecodeBase64")


@tf_export("strings.to_number", v1=[])
@dispatch.add_dispatch_support
def string_to_number(input, out_type=dtypes.float32, name=None):
  r"""Converts each string in the input Tensor to the specified numeric type.

  (Note that int32 overflow results in an error while float overflow
  results in a rounded value.)

  Args:
    input: A `Tensor` of type `string`.
    out_type: An optional `tf.DType` from: `tf.float32, tf.float64, tf.int32,
      tf.int64`. Defaults to `tf.float32`.
      The numeric type to interpret each string in `string_tensor` as.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `out_type`.
  """
  return gen_parsing_ops.string_to_number(input, out_type, name)


@tf_export(v1=["strings.to_number", "string_to_number"])
def string_to_number_v1(
    string_tensor=None,
    out_type=dtypes.float32,
    name=None,
    input=None):
  # V1 endpoint: accepts the deprecated `input` alias for `string_tensor`.
  string_tensor = deprecation.deprecated_argument_lookup(
      "input", input, "string_tensor", string_tensor)
  return gen_parsing_ops.string_to_number(string_tensor, out_type, name)

string_to_number_v1.__doc__ = gen_parsing_ops.string_to_number.__doc__


@tf_export("strings.to_hash_bucket", v1=[])
@dispatch.add_dispatch_support
def string_to_hash_bucket(input, num_buckets, name=None):
  # pylint: disable=line-too-long
  r"""Converts each string in the input Tensor to its hash mod by a number of buckets.

  The hash function is deterministic on the content of the string within the
  process.

  Note that the hash function may change from time to time.
  This functionality will be deprecated and it's recommended to use
  `tf.string_to_hash_bucket_fast()` or `tf.string_to_hash_bucket_strong()`.

  Args:
    input: A `Tensor` of type `string`.
    num_buckets: An `int` that is `>= 1`. The number of buckets.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `int64`.
  """
  # pylint: enable=line-too-long
  return gen_string_ops.string_to_hash_bucket(input, num_buckets, name)


@tf_export(v1=["strings.to_hash_bucket", "string_to_hash_bucket"])
def string_to_hash_bucket_v1(
    string_tensor=None,
    num_buckets=None,
    name=None,
    input=None):
  # V1 endpoint: accepts the deprecated `input` alias for `string_tensor`.
  string_tensor = deprecation.deprecated_argument_lookup(
      "input", input, "string_tensor", string_tensor)
  return gen_string_ops.string_to_hash_bucket(string_tensor, num_buckets, name)

string_to_hash_bucket_v1.__doc__ = gen_string_ops.string_to_hash_bucket.__doc__