# -*- coding: utf-8 -*-
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Operations for working with string Tensors."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_parsing_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import math_ops

# go/tf-wildcard-import
# pylint: disable=wildcard-import
# pylint: disable=g-bad-import-order
from tensorflow.python.ops.gen_string_ops import *
from tensorflow.python.util import compat as util_compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=g-bad-import-order
# pylint: enable=wildcard-import


# pylint: disable=redefined-builtin
@tf_export("strings.regex_full_match")
@dispatch.add_dispatch_support
def regex_full_match(input, pattern, name=None):
  r"""Match elements of `input` with regex `pattern`.

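  For example (an illustrative doctest, assuming eager execution):

  >>> tf.strings.regex_full_match(["TF lib", "lib TF"], ".*lib$")
  <tf.Tensor: shape=(2,), dtype=bool, numpy=array([ True, False])>
  >>> tf.strings.regex_full_match(["TF lib", "lib TF"], ".*TF$")
  <tf.Tensor: shape=(2,), dtype=bool, numpy=array([False,  True])>
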
  Args:
    input: string `Tensor`, the source strings to process.
    pattern: string or scalar string `Tensor`, regular expression to use,
      see more details at https://github.com/google/re2/wiki/Syntax
    name: Name of the op.

  Returns:
    bool `Tensor` of the same shape as `input` with match results.
  """
  if isinstance(pattern, util_compat.bytes_or_text_types):
    # When `pattern` is static through the life of the op we can
    # use a version which performs the expensive regex compilation once at
    # creation time.
    return gen_string_ops.static_regex_full_match(
        input=input, pattern=pattern, name=name)
  return gen_string_ops.regex_full_match(
      input=input, pattern=pattern, name=name)

regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__


@tf_export(
    "strings.regex_replace", v1=["strings.regex_replace", "regex_replace"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("regex_replace")
def regex_replace(input, pattern, rewrite, replace_global=True, name=None):
  r"""Replace elements of `input` matching regex `pattern` with `rewrite`.

  >>> tf.strings.regex_replace("Text with tags.<br /><b>contains html</b>",
  ...                          "<[^>]+>", " ")
  <tf.Tensor: shape=(), dtype=string, numpy=b'Text with tags.  contains html '>
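
  Setting `replace_global=False` replaces only the first match (an
  illustrative example):

  >>> tf.strings.regex_replace("aaa", "a", "b", replace_global=False)
  <tf.Tensor: shape=(), dtype=string, numpy=b'baa'>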

  Args:
    input: string `Tensor`, the source strings to process.
    pattern: string or scalar string `Tensor`, regular expression to use,
      see more details at https://github.com/google/re2/wiki/Syntax
    rewrite: string or scalar string `Tensor`, value to use in match
      replacement; supports backslash-escaped digits (\1 to \9), which can be
      used to insert text matching the corresponding parenthesized group.
    replace_global: `bool`, if `True` replace all non-overlapping matches,
      else replace only the first match.
    name: A name for the operation (optional).

  Returns:
    string `Tensor` of the same shape as `input` with specified replacements.
  """
  if (isinstance(pattern, util_compat.bytes_or_text_types) and
      isinstance(rewrite, util_compat.bytes_or_text_types)):
    # When `pattern` and `rewrite` are static through the life of the op we can
    # use a version which performs the expensive regex compilation once at
    # creation time.
    return gen_string_ops.static_regex_replace(
        input=input, pattern=pattern,
        rewrite=rewrite, replace_global=replace_global,
        name=name)
  return gen_string_ops.regex_replace(
      input=input, pattern=pattern,
      rewrite=rewrite, replace_global=replace_global,
      name=name)


@tf_export("strings.format")
@dispatch.add_dispatch_support
def string_format(template, inputs, placeholder="{}", summarize=3, name=None):
  r"""Formats a string template using a list of tensors.

  Formats a string template using a list of tensors, abbreviating tensors by
  only printing the first and last `summarize` elements of each dimension
  (recursively). If formatting only one tensor into a template, the tensor does
  not have to be wrapped in a list.

  Example:
    Formatting a single-tensor template:

    >>> tensor = tf.range(5)
    >>> tf.strings.format("tensor: {}, suffix", tensor)
    <tf.Tensor: shape=(), dtype=string, numpy=b'tensor: [0 1 2 3 4], suffix'>

    Formatting a multi-tensor template:

    >>> tensor_a = tf.range(2)
    >>> tensor_b = tf.range(1, 4, 2)
    >>> tf.strings.format("a: {}, b: {}, suffix", (tensor_a, tensor_b))
    <tf.Tensor: shape=(), dtype=string, numpy=b'a: [0 1], b: [1 3], suffix'>
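
    A custom `placeholder` string may be supplied (an illustrative example;
    `%t%` is an arbitrary marker chosen here):

    >>> tf.strings.format("tensor: %t%", tf.range(3), placeholder="%t%")
    <tf.Tensor: shape=(), dtype=string, numpy=b'tensor: [0 1 2]'>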


  Args:
    template: A string template to format tensor values into.
    inputs: A list of `Tensor` objects, or a single Tensor.
      The list of tensors to format into the template string. If a solitary
      tensor is passed in, the input tensor will automatically be wrapped as a
      list.
    placeholder: An optional `string`. Defaults to `{}`.
      At each placeholder occurring in the template, a subsequent tensor
      will be inserted.
    summarize: An optional `int`. Defaults to `3`.
      When formatting the tensors, show the first and last `summarize`
      entries of each tensor dimension (recursively). If set to -1, all
      elements of the tensor will be shown.
    name: A name for the operation (optional).

  Returns:
    A scalar `Tensor` of type `string`.

  Raises:
    ValueError: if the number of placeholders does not match the number of
      inputs.
  """
  # If there is only one tensor to format, we will automatically wrap it in a
  # list to simplify the user experience.
  if tensor_util.is_tf_type(inputs):
    inputs = [inputs]
  if template.count(placeholder) != len(inputs):
    raise ValueError(f"The template expects {template.count(placeholder)} "
                     f"tensors, but {len(inputs)} were provided. Please "
                     "ensure the number of placeholders in the template "
                     "matches the number of inputs.")

  return gen_string_ops.string_format(inputs,
                                      template=template,
                                      placeholder=placeholder,
                                      summarize=summarize,
                                      name=name)


# Note: tf.strings.split is exported in ragged/ragged_string_ops.py, which
# defines a wrapper for this function.
def string_split(source, sep=None, skip_empty=True, delimiter=None):  # pylint: disable=invalid-name
  """Split elements of `source` based on `delimiter` into a `SparseTensor`.

  Let N be the size of source (typically N will be the batch size). Split each
  element of `source` based on `delimiter` and return a `SparseTensor`
  containing the split tokens. Empty tokens are ignored when `skip_empty` is
  True (the default).

  If `delimiter` is an empty string, each element of the `source` is split
  into individual strings, each containing one byte. (This includes splitting
  multibyte sequences of UTF-8.) If `delimiter` contains multiple bytes, it is
  treated as a set of delimiters, with each byte considered a potential split
  point.
  For example, if N = 2, source[0] is 'hello world' and source[1] is 'a b c',
  then the output will be

  st.indices = [0, 0;
                0, 1;
                1, 0;
                1, 1;
                1, 2]
  st.shape = [2, 3]
  st.values = ['hello', 'world', 'a', 'b', 'c']

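  A hedged usage sketch of `skip_empty` (illustrative only; this helper is
  normally reached through the `tf.strings.split` wrapper):

    st = string_split(tf.constant(["a  b"]))  # note the double space
    # st.values => [b'a', b'b']        (the empty token is skipped)
    st = string_split(tf.constant(["a  b"]), skip_empty=False)
    # st.values => [b'a', b'', b'b']
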
  Args:
    source: `1-D` string `Tensor`, the strings to split.
    sep: `0-D` string `Tensor`, the delimiter string; it should have length 0
      or 1. Defaults to ' '.
    skip_empty: A `bool`. If `True`, skip the empty strings from the result.
    delimiter: deprecated alias for `sep`.

  Raises:
    ValueError: If delimiter is not a string.

  Returns:
    A `SparseTensor` of rank `2`, the strings split according to the delimiter.
    The first column of the indices corresponds to the row in `source` and the
    second column corresponds to the index of the split component in this row.
  """
  delimiter = deprecation.deprecated_argument_lookup(
      "sep", sep, "delimiter", delimiter)

  if delimiter is None:
    delimiter = " "
  delimiter = ops.convert_to_tensor(delimiter, dtype=dtypes.string)
  source = ops.convert_to_tensor(source, dtype=dtypes.string)

  indices, values, shape = gen_string_ops.string_split(
      source, delimiter=delimiter, skip_empty=skip_empty)
  indices.set_shape([None, 2])
  values.set_shape([None])
  shape.set_shape([2])
  return sparse_tensor.SparseTensor(indices, values, shape)


# Note: tf.strings.split is exported in ragged/ragged_string_ops.py, which
# defines a wrapper for this function.
def string_split_v2(source, sep=None, maxsplit=-1):
  """Split elements of `source` based on `sep` into a `SparseTensor`.

  Let N be the size of source (typically N will be the batch size). Split each
  element of `source` based on `sep` and return a `SparseTensor`
  containing the split tokens. Empty tokens are ignored when `sep` is None
  (see below).

  For example, if N = 2, source[0] is 'hello world' and source[1] is 'a b c',
  then the output will be

  st.indices = [0, 0;
                0, 1;
                1, 0;
                1, 1;
                1, 2]
  st.shape = [2, 3]
  st.values = ['hello', 'world', 'a', 'b', 'c']

  If `sep` is given, consecutive delimiters are not grouped together and are
  deemed to delimit empty strings. For example, a source of `"1<>2<><>3"` and
  a sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
  string, consecutive whitespace characters are regarded as a single
  separator, and the result will contain no empty strings at the start or end
  if the string has leading or trailing whitespace.

  Note that the above-mentioned behavior matches Python's `str.split`.
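
  A hedged sketch of the `maxsplit` behavior (illustrative only; this helper
  is normally reached through the `tf.strings.split` wrapper):

    st = string_split_v2(tf.constant(["a b c d"]), sep=" ", maxsplit=1)
    # st.values => [b'a', b'b c d']   (at most maxsplit + 1 tokens)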

  Args:
    source: `1-D` string `Tensor`, the strings to split.
    sep: `0-D` string `Tensor`, the delimiter string.
    maxsplit: An `int`. If `maxsplit > 0`, limits the number of splits, so the
      result has at most `maxsplit + 1` elements per input string.

  Raises:
    ValueError: If sep is not a string.

  Returns:
    A `SparseTensor` of rank `2`, the strings split according to the delimiter.
    The first column of the indices corresponds to the row in `source` and the
    second column corresponds to the index of the split component in this row.
  """
  if sep is None:
    sep = ""
  sep = ops.convert_to_tensor(sep, dtype=dtypes.string)
  source = ops.convert_to_tensor(source, dtype=dtypes.string)

  indices, values, shape = gen_string_ops.string_split_v2(
      source, sep=sep, maxsplit=maxsplit)
  indices.set_shape([None, 2])
  values.set_shape([None])
  shape.set_shape([2])
  return sparse_tensor.SparseTensor(indices, values, shape)


def _reduce_join_reduction_dims(x, axis):
  """Returns range(rank(x) - 1, -1, -1) if axis is None; or axis otherwise."""
  if axis is not None:
    return axis
  else:
    # Fast path: avoid creating Rank and Range ops if ndims is known.
    if x.get_shape().ndims is not None:
      return constant_op.constant(
          np.arange(x.get_shape().ndims - 1, -1, -1), dtype=dtypes.int32)

    # Otherwise, we rely on Range and Rank to do the right thing at run-time.
    return math_ops.range(array_ops.rank(x) - 1, -1, -1)


@tf_export(v1=["strings.reduce_join", "reduce_join"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None,
                             "keep_dims is deprecated, use keepdims instead",
                             "keep_dims")
@deprecation.deprecated_endpoints("reduce_join")
def reduce_join(inputs, axis=None,  # pylint: disable=missing-docstring
                keep_dims=None,
                separator="",
                name=None,
                reduction_indices=None,
                keepdims=None):
  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
                                                    "keep_dims", keep_dims)
  if keep_dims is None:
    keep_dims = False
  axis = deprecation.deprecated_argument_lookup("axis", axis,
                                                "reduction_indices",
                                                reduction_indices)
  return reduce_join_v2(
      inputs=inputs,
      axis=axis,
      keepdims=keepdims,
      separator=separator,
      name=name)


@tf_export("strings.reduce_join", v1=[])
@dispatch.add_dispatch_support
def reduce_join_v2(  # pylint: disable=missing-docstring
    inputs,
    axis=None,
    keepdims=False,
    separator="",
    name=None):
  """Joins all strings into a single string, or joins along an axis.

  This is the reduction operation for the elementwise `tf.strings.join` op.

  >>> tf.strings.reduce_join([['abc','123'],
  ...                         ['def','456']]).numpy()
  b'abc123def456'
  >>> tf.strings.reduce_join([['abc','123'],
  ...                         ['def','456']], axis=-1).numpy()
  array([b'abc123', b'def456'], dtype=object)
  >>> tf.strings.reduce_join([['abc','123'],
  ...                         ['def','456']],
  ...                        axis=-1,
  ...                        separator=" ").numpy()
  array([b'abc 123', b'def 456'], dtype=object)
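
  Joining along `axis=0` instead concatenates down the columns (an
  illustrative example):

  >>> tf.strings.reduce_join([['abc','123'],
  ...                         ['def','456']], axis=0).numpy()
  array([b'abcdef', b'123456'], dtype=object)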

  Args:
    inputs: A `tf.string` tensor.
    axis: Which axis to join along. The default behavior is to join all
      elements, producing a scalar.
    keepdims: If true, retains reduced dimensions with length 1.
    separator: a string added between each string being joined.
    name: A name for the operation (optional).

  Returns:
    A `tf.string` tensor.
  """
  with ops.name_scope(None, "ReduceJoin", [inputs, axis]):
    inputs_t = ops.convert_to_tensor(inputs)
    axis = _reduce_join_reduction_dims(inputs_t, axis)
    return gen_string_ops.reduce_join(
        inputs=inputs_t,
        reduction_indices=axis,
        keep_dims=keepdims,
        separator=separator,
        name=name)

reduce_join.__doc__ = reduce_join_v2.__doc__


# This wrapper provides backwards compatibility for code that predates the
# unit argument and that passed 'name' as a positional argument.
@tf_export(v1=["strings.length"])
@dispatch.add_dispatch_support
def string_length(input, name=None, unit="BYTE"):
  """Computes the length of each string given in the input tensor.

  >>> strings = tf.constant(['Hello','TensorFlow', '🙂'])
  >>> tf.strings.length(strings).numpy() # default counts bytes
  array([ 5, 10, 4], dtype=int32)
  >>> tf.strings.length(strings, unit="UTF8_CHAR").numpy()
  array([ 5, 10, 1], dtype=int32)

  Args:
    input: A `Tensor` of type `string`. The strings whose lengths are to be
      computed.
    name: A name for the operation (optional).
    unit: An optional `string` from: `"BYTE", "UTF8_CHAR"`. Defaults to
      `"BYTE"`. The unit that is counted to compute string length. One of:
      `"BYTE"` (for the number of bytes in each string) or `"UTF8_CHAR"` (for
      the number of UTF-8 encoded Unicode code points in each string). Results
      are undefined if `unit=UTF8_CHAR` and the `input` strings do not contain
      structurally valid UTF-8.

  Returns:
    A `Tensor` of type `int32`, containing the length of each string in the
    input tensor.
  """
  return gen_string_ops.string_length(input, unit=unit, name=name)


@tf_export("strings.length", v1=[])
@dispatch.add_dispatch_support
def string_length_v2(input, unit="BYTE", name=None):
  return gen_string_ops.string_length(input, unit=unit, name=name)


string_length_v2.__doc__ = gen_string_ops.string_length.__doc__


@tf_export(v1=["substr"])
@dispatch.add_dispatch_support
@deprecation.deprecated(None, "Use `tf.strings.substr` instead of `tf.substr`.")
def substr_deprecated(input, pos, len, name=None, unit="BYTE"):
  return substr(input, pos, len, name=name, unit=unit)

substr_deprecated.__doc__ = gen_string_ops.substr.__doc__


@tf_export(v1=["strings.substr"])
@dispatch.add_dispatch_support
def substr(input, pos, len, name=None, unit="BYTE"):
  return gen_string_ops.substr(input, pos, len, unit=unit, name=name)

substr.__doc__ = gen_string_ops.substr.__doc__


@tf_export("strings.substr", v1=[])
@dispatch.add_dispatch_support
def substr_v2(input, pos, len, unit="BYTE", name=None):
  return gen_string_ops.substr(input, pos, len, unit=unit, name=name)

substr_v2.__doc__ = gen_string_ops.substr.__doc__
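
# Illustrative usage sketch for the substr family (assuming eager execution;
# `pos` and `len` are interpreted in the given `unit`, bytes by default):
#
#   tf.strings.substr("TensorFlow", pos=0, len=6)   # => b'Tensor'
#   tf.strings.substr("TensorFlow", pos=6, len=4)   # => b'Flow'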


ops.NotDifferentiable("RegexReplace")
ops.NotDifferentiable("StringToHashBucket")
ops.NotDifferentiable("StringToHashBucketFast")
ops.NotDifferentiable("StringToHashBucketStrong")
ops.NotDifferentiable("ReduceJoin")
ops.NotDifferentiable("StringJoin")
ops.NotDifferentiable("StringSplit")
ops.NotDifferentiable("AsString")
ops.NotDifferentiable("EncodeBase64")
ops.NotDifferentiable("DecodeBase64")


@tf_export("strings.to_number", v1=[])
@dispatch.add_dispatch_support
def string_to_number(input, out_type=dtypes.float32, name=None):
  r"""Converts each string in the input Tensor to the specified numeric type.

  (Note that int32 overflow results in an error while float overflow
  results in a rounded value.)

  Examples:

  >>> tf.strings.to_number("1.55")
  <tf.Tensor: shape=(), dtype=float32, numpy=1.55>
  >>> tf.strings.to_number("3", tf.int32)
  <tf.Tensor: shape=(), dtype=int32, numpy=3>

  Args:
    input: A `Tensor` of type `string`.
    out_type: An optional `tf.DType` from: `tf.float32, tf.float64, tf.int32,
      tf.int64`. Defaults to `tf.float32`.
      The numeric type to interpret each string in `input` as.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `out_type`.
  """
  return gen_parsing_ops.string_to_number(input, out_type, name)


@tf_export(v1=["strings.to_number", "string_to_number"])
@dispatch.add_dispatch_support
def string_to_number_v1(
    string_tensor=None,
    out_type=dtypes.float32,
    name=None,
    input=None):
  string_tensor = deprecation.deprecated_argument_lookup(
      "input", input, "string_tensor", string_tensor)
  return gen_parsing_ops.string_to_number(string_tensor, out_type, name)

string_to_number_v1.__doc__ = gen_parsing_ops.string_to_number.__doc__


@tf_export("strings.to_hash_bucket", v1=[])
@dispatch.add_dispatch_support
def string_to_hash_bucket(input, num_buckets, name=None):
  # pylint: disable=line-too-long
  r"""Converts each string in the input Tensor to its hash modulo a number of buckets.

  The hash function is deterministic on the content of the string within the
  process.

  Note that the hash function may change from time to time.
  This functionality will be deprecated; it is recommended to use
  `tf.strings.to_hash_bucket_fast()` or `tf.strings.to_hash_bucket_strong()`
  instead.

  Examples:

  >>> tf.strings.to_hash_bucket(["Hello", "TensorFlow", "2.x"], 3)
  <tf.Tensor: shape=(3,), dtype=int64, numpy=array([2, 0, 1])>

  Args:
    input: A `Tensor` of type `string`.
    num_buckets: An `int` that is `>= 1`. The number of buckets.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `int64`.
  """
  # pylint: enable=line-too-long
  return gen_string_ops.string_to_hash_bucket(input, num_buckets, name)


@tf_export(v1=["strings.to_hash_bucket", "string_to_hash_bucket"])
@dispatch.add_dispatch_support
def string_to_hash_bucket_v1(
    string_tensor=None,
    num_buckets=None,
    name=None,
    input=None):
  string_tensor = deprecation.deprecated_argument_lookup(
      "input", input, "string_tensor", string_tensor)
  return gen_string_ops.string_to_hash_bucket(string_tensor, num_buckets, name)

string_to_hash_bucket_v1.__doc__ = gen_string_ops.string_to_hash_bucket.__doc__


@tf_export("strings.join", v1=["strings.join", "string_join"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("string_join")
def string_join(inputs, separator="", name=None):
  """Perform element-wise concatenation of a list of string tensors.

  Given a list of string tensors of the same shape, performs element-wise
  concatenation of the strings of the same index in all tensors.

  >>> tf.strings.join(['abc','def']).numpy()
  b'abcdef'
  >>> tf.strings.join([['abc','123'],
  ...                  ['def','456'],
  ...                  ['ghi','789']]).numpy()
  array([b'abcdefghi', b'123456789'], dtype=object)
  >>> tf.strings.join([['abc','123'],
  ...                  ['def','456']],
  ...                  separator=" ").numpy()
  array([b'abc def', b'123 456'], dtype=object)

  The reduction version of this elementwise operation is
  `tf.strings.reduce_join`.

  Args:
    inputs: A list of `tf.Tensor` objects of the same size and `tf.string`
      dtype.
    separator: A string added between each string being joined.
    name: A name for the operation (optional).

  Returns:
    A `tf.string` tensor.
  """
  return gen_string_ops.string_join(inputs, separator=separator, name=name)