# -*- coding: utf-8 -*-
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Operations for working with string Tensors."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_parsing_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import math_ops

# go/tf-wildcard-import
# pylint: disable=wildcard-import
# pylint: disable=g-bad-import-order
from tensorflow.python.ops.gen_string_ops import *
from tensorflow.python.util import compat as util_compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=g-bad-import-order
# pylint: enable=wildcard-import


# pylint: disable=redefined-builtin
@tf_export("strings.regex_full_match")
@dispatch.add_dispatch_support
def regex_full_match(input, pattern, name=None):
  r"""Match elements of `input` with regex `pattern`.

  Args:
    input: string `Tensor`, the source strings to process.
    pattern: string or scalar string `Tensor`, regular expression to use,
      see more details at https://github.com/google/re2/wiki/Syntax
    name: Name of the op.

  Returns:
    bool `Tensor` of the same shape as `input` with match results.
  """
  if isinstance(pattern, util_compat.bytes_or_text_types):
    # When `pattern` is static through the life of the op we can
    # use a version which performs the expensive regex compilation once at
    # creation time.
    return gen_string_ops.static_regex_full_match(
        input=input, pattern=pattern, name=name)
  return gen_string_ops.regex_full_match(
      input=input, pattern=pattern, name=name)

regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__
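
# Illustrative usage sketch (assumes TF 2.x eager execution; output shown is
# indicative, not a doctest). Note the full-match semantics: the pattern must
# consume the entire string, so "flow" below does not match "Tensor.*".
#
#   tf.strings.regex_full_match(["TensorFlow", "flow"], "Tensor.*").numpy()
#   # => array([ True, False])

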
@tf_export(
    "strings.regex_replace", v1=["strings.regex_replace", "regex_replace"])
@deprecation.deprecated_endpoints("regex_replace")
@dispatch.add_dispatch_support
def regex_replace(input, pattern, rewrite, replace_global=True, name=None):
  r"""Replace elements of `input` matching regex `pattern` with `rewrite`.

  >>> tf.strings.regex_replace("Text with tags.<br /><b>contains html</b>",
  ...                          "<[^>]+>", " ")
  <tf.Tensor: shape=(), dtype=string, numpy=b'Text with tags.  contains html '>

  Args:
    input: string `Tensor`, the source strings to process.
    pattern: string or scalar string `Tensor`, regular expression to use,
      see more details at https://github.com/google/re2/wiki/Syntax
    rewrite: string or scalar string `Tensor`, value to use in match
      replacement. Supports backslash-escaped digits (\1 to \9) to insert
      text matching the corresponding parenthesized group.
    replace_global: `bool`, if `True` replace all non-overlapping matches,
      else replace only the first match.
    name: A name for the operation (optional).

  Returns:
    string `Tensor` of the same shape as `input` with specified replacements.
  """
  if (isinstance(pattern, util_compat.bytes_or_text_types) and
      isinstance(rewrite, util_compat.bytes_or_text_types)):
    # When `pattern` and `rewrite` are static through the life of the op we can
    # use a version which performs the expensive regex compilation once at
    # creation time.
    return gen_string_ops.static_regex_replace(
        input=input, pattern=pattern,
        rewrite=rewrite, replace_global=replace_global,
        name=name)
  return gen_string_ops.regex_replace(
      input=input, pattern=pattern,
      rewrite=rewrite, replace_global=replace_global,
      name=name)
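
# Illustrative sketch of `replace_global` (assumes eager execution; outputs
# indicative): with `replace_global=False`, only the first match in each
# element is rewritten.
#
#   tf.strings.regex_replace("a-b-c", "-", "+").numpy()
#   # => b'a+b+c'
#   tf.strings.regex_replace("a-b-c", "-", "+", replace_global=False).numpy()
#   # => b'a+b-c'

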
@tf_export("strings.format")
def string_format(template, inputs, placeholder="{}", summarize=3, name=None):
  r"""Formats a string template using a list of tensors.

  Formats a string template using a list of tensors, abbreviating tensors by
  only printing the first and last `summarize` elements of each dimension
  (recursively). If formatting only one tensor into a template, the tensor does
  not have to be wrapped in a list.

  Example:
    Formatting a single-tensor template:
    ```python
    sess = tf.compat.v1.Session()
    with sess.as_default():
        tensor = tf.range(10)
        formatted = tf.strings.format("tensor: {}, suffix", tensor)
        out = sess.run(formatted)
        expected = "tensor: [0 1 2 ... 7 8 9], suffix"

        assert(out.decode() == expected)
    ```

    Formatting a multi-tensor template:
    ```python
    sess = tf.compat.v1.Session()
    with sess.as_default():
        tensor_one = tf.reshape(tf.range(100), [10, 10])
        tensor_two = tf.range(10)
        formatted = tf.strings.format("first: {}, second: {}, suffix",
          (tensor_one, tensor_two))

        out = sess.run(formatted)
        expected = ("first: [[0 1 2 ... 7 8 9]\n"
              " [10 11 12 ... 17 18 19]\n"
              " [20 21 22 ... 27 28 29]\n"
              " ...\n"
              " [70 71 72 ... 77 78 79]\n"
              " [80 81 82 ... 87 88 89]\n"
              " [90 91 92 ... 97 98 99]], second: [0 1 2 ... 7 8 9], suffix")

        assert(out.decode() == expected)
    ```

  Args:
    template: A string template to format tensor values into.
    inputs: A list of `Tensor` objects, or a single Tensor.
      The list of tensors to format into the template string. If a solitary
      tensor is passed in, the input tensor will automatically be wrapped as a
      list.
    placeholder: An optional `string`. Defaults to `{}`.
      At each placeholder occurring in the template, a subsequent tensor
      will be inserted.
    summarize: An optional `int`. Defaults to `3`.
      When formatting the tensors, show the first and last `summarize`
      entries of each tensor dimension (recursively). If set to -1, all
      elements of the tensor will be shown.
    name: A name for the operation (optional).

  Returns:
    A scalar `Tensor` of type `string`.

  Raises:
    ValueError: if the number of placeholders does not match the number of
      inputs.
  """
  # If there is only one tensor to format, we will automatically wrap it in a
  # list to simplify the user experience.
  if tensor_util.is_tensor(inputs):
    inputs = [inputs]
  if template.count(placeholder) != len(inputs):
    raise ValueError("%s placeholder(s) in template does not match %s tensor(s)"
                     " provided as input" % (template.count(placeholder),
                                             len(inputs)))

  return gen_string_ops.string_format(inputs,
                                      template=template,
                                      placeholder=placeholder,
                                      summarize=summarize,
                                      name=name)
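
# Illustrative eager-mode sketch (TF 2.x; outputs indicative). With the
# default `summarize=3`, a dimension is abbreviated only when it holds more
# than six elements (the first three and last three are kept):
#
#   tf.strings.format("tensor: {}", tf.range(4)).numpy()
#   # => b'tensor: [0 1 2 3]'
#   tf.strings.format("tensor: {}", tf.range(10)).numpy()
#   # => b'tensor: [0 1 2 ... 7 8 9]'

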
# Note: tf.strings.split is exported in ragged/ragged_string_ops.py, which
# defines a wrapper for this function.
def string_split(source, sep=None, skip_empty=True, delimiter=None):  # pylint: disable=invalid-name
  """Split elements of `source` based on `delimiter` into a `SparseTensor`.

  Let N be the size of `source` (typically N will be the batch size). Split
  each element of `source` based on `delimiter` and return a `SparseTensor`
  containing the split tokens. Empty tokens are ignored.

  If `sep` is an empty string, each element of `source` is split into
  individual strings, each containing one byte. (This includes splitting
  multibyte sequences of UTF-8.) If `sep` contains multiple bytes, it is
  treated as a set of delimiters, each of which marks a potential split point.

  For example, if N = 2, source[0] is 'hello world' and source[1] is 'a b c',
  then the output will be

  st.indices = [0, 0;
                0, 1;
                1, 0;
                1, 1;
                1, 2]
  st.shape = [2, 3]
  st.values = ['hello', 'world', 'a', 'b', 'c']

  Args:
    source: `1-D` string `Tensor`, the strings to split.
    sep: `0-D` string `Tensor`, the delimiter character; the string should
      have length 0 or 1. Defaults to ' '.
    skip_empty: A `bool`. If `True`, skip the empty strings from the result.
    delimiter: deprecated alias for `sep`.

  Raises:
    ValueError: If delimiter is not a string.

  Returns:
    A `SparseTensor` of rank `2`, the strings split according to the delimiter.
    The first column of the indices corresponds to the row in `source` and the
    second column corresponds to the index of the split component in this row.
  """
  delimiter = deprecation.deprecated_argument_lookup(
      "sep", sep, "delimiter", delimiter)

  if delimiter is None:
    delimiter = " "
  delimiter = ops.convert_to_tensor(delimiter, dtype=dtypes.string)
  source = ops.convert_to_tensor(source, dtype=dtypes.string)

  indices, values, shape = gen_string_ops.string_split(
      source, delimiter=delimiter, skip_empty=skip_empty)
  indices.set_shape([None, 2])
  values.set_shape([None])
  shape.set_shape([2])
  return sparse_tensor.SparseTensor(indices, values, shape)
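
# Illustrative sketch (outputs indicative): with the default sep=' ' and
# skip_empty=True, the empty token produced by the doubled space is dropped.
#
#   st = string_split(tf.constant(["hello  world"]))
#   # st.values => [b'hello', b'world']

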
# Note: tf.strings.split is exported in ragged/ragged_string_ops.py, which
# defines a wrapper for this function.
def string_split_v2(source, sep=None, maxsplit=-1):
  """Split elements of `source` based on `sep` into a `SparseTensor`.

  Let N be the size of `source` (typically N will be the batch size). Split
  each element of `source` based on `sep` and return a `SparseTensor`
  containing the split tokens. Empty tokens are ignored.

  For example, if N = 2, source[0] is 'hello world' and source[1] is 'a b c',
  then the output will be

  st.indices = [0, 0;
                0, 1;
                1, 0;
                1, 1;
                1, 2]
  st.shape = [2, 3]
  st.values = ['hello', 'world', 'a', 'b', 'c']

  If `sep` is given, consecutive delimiters are not grouped together and are
  deemed to delimit empty strings. For example, a source of `"1<>2<><>3"` with
  a sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
  string, runs of consecutive whitespace are regarded as a single separator,
  and the result will contain no empty strings at the start or end if the
  string has leading or trailing whitespace.

  Note that the behavior described above matches Python's str.split.

  Args:
    source: `1-D` string `Tensor`, the strings to split.
    sep: `0-D` string `Tensor`, the delimiter character.
    maxsplit: An `int`. If `maxsplit > 0`, limits the number of splits
      performed on each element, so the result has at most `maxsplit + 1`
      tokens per element.

  Raises:
    ValueError: If sep is not a string.

  Returns:
    A `SparseTensor` of rank `2`, the strings split according to the delimiter.
    The first column of the indices corresponds to the row in `source` and the
    second column corresponds to the index of the split component in this row.
  """
  if sep is None:
    sep = ""
  sep = ops.convert_to_tensor(sep, dtype=dtypes.string)
  source = ops.convert_to_tensor(source, dtype=dtypes.string)

  indices, values, shape = gen_string_ops.string_split_v2(
      source, sep=sep, maxsplit=maxsplit)
  indices.set_shape([None, 2])
  values.set_shape([None])
  shape.set_shape([2])
  return sparse_tensor.SparseTensor(indices, values, shape)
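
# Illustrative sketch of `maxsplit` (outputs indicative): with maxsplit=1,
# each element is split at most once and the remainder is kept intact,
# mirroring Python's str.split.
#
#   st = string_split_v2(tf.constant(["a b c"]), sep=" ", maxsplit=1)
#   # st.values => [b'a', b'b c']

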
def _reduce_join_reduction_dims(x, axis):
  """Returns range(rank(x) - 1, -1, -1) if axis is None; or axis otherwise."""
  if axis is not None:
    return axis
  else:
    # Fast path: avoid creating Rank and Range ops if ndims is known.
    if x.get_shape().ndims is not None:
      return constant_op.constant(
          np.arange(x.get_shape().ndims - 1, -1, -1), dtype=dtypes.int32)

    # Otherwise, we rely on Range and Rank to do the right thing at run-time.
    return math_ops.range(array_ops.rank(x) - 1, -1, -1)


@tf_export(v1=["strings.reduce_join", "reduce_join"])
@deprecation.deprecated_args(None,
                             "keep_dims is deprecated, use keepdims instead",
                             "keep_dims")
@deprecation.deprecated_endpoints("reduce_join")
def reduce_join(inputs, axis=None,  # pylint: disable=missing-docstring
                keep_dims=None,
                separator="",
                name=None,
                reduction_indices=None,
                keepdims=None):
  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
                                                    "keep_dims", keep_dims)
  if keep_dims is None:
    keep_dims = False
  axis = deprecation.deprecated_argument_lookup("axis", axis,
                                                "reduction_indices",
                                                reduction_indices)
  return reduce_join_v2(
      inputs=inputs,
      axis=axis,
      keepdims=keepdims,
      separator=separator,
      name=name)


@tf_export("strings.reduce_join", v1=[])
@dispatch.add_dispatch_support
def reduce_join_v2(  # pylint: disable=missing-docstring
    inputs,
    axis=None,
    keepdims=False,
    separator="",
    name=None):
  """Joins all strings into a single string, or joins along an axis.

  >>> tf.strings.reduce_join([['abc','123'],
  ...                         ['def','456']]).numpy()
  b'abc123def456'
  >>> tf.strings.reduce_join([['abc','123'],
  ...                         ['def','456']], axis=-1).numpy()
  array([b'abc123', b'def456'], dtype=object)
  >>> tf.strings.reduce_join([['abc','123'],
  ...                         ['def','456']],
  ...                        axis=-1,
  ...                        separator=" ").numpy()
  array([b'abc 123', b'def 456'], dtype=object)

  Args:
    inputs: A `tf.string` tensor.
    axis: Which axis to join along. The default behavior is to join all
      elements, producing a scalar.
    keepdims: If true, retains reduced dimensions with length 1.
    separator: a string added between each string being joined.
    name: A name for the operation (optional).

  Returns:
    A `tf.string` tensor.
  """
  with ops.name_scope(None, "ReduceJoin", [inputs, axis]):
    inputs_t = ops.convert_to_tensor(inputs)
    axis = _reduce_join_reduction_dims(inputs_t, axis)
    return gen_string_ops.reduce_join(
        inputs=inputs_t,
        reduction_indices=axis,
        keep_dims=keepdims,
        separator=separator,
        name=name)

reduce_join.__doc__ = reduce_join_v2.__doc__


# This wrapper provides backwards compatibility for code that predates the
# unit argument and that passed 'name' as a positional argument.
@tf_export(v1=["strings.length"])
@dispatch.add_dispatch_support
def string_length(input, name=None, unit="BYTE"):
  """Computes the length of each string given in the input tensor.

  >>> strings = tf.constant(['Hello','TensorFlow', '🙂'])
  >>> tf.strings.length(strings).numpy() # default counts bytes
  array([ 5, 10, 4], dtype=int32)
  >>> tf.strings.length(strings, unit="UTF8_CHAR").numpy()
  array([ 5, 10, 1], dtype=int32)

  Args:
    input: A `Tensor` of type `string`. The strings for which to compute the
      length of each element.
    name: A name for the operation (optional).
    unit: An optional `string` from: `"BYTE", "UTF8_CHAR"`. Defaults to
      `"BYTE"`. The unit that is counted to compute string length.  One of:
        `"BYTE"` (for the number of bytes in each string) or `"UTF8_CHAR"` (for
        the number of UTF-8 encoded Unicode code points in each string). Results
        are undefined if `unit=UTF8_CHAR` and the `input` strings do not contain
        structurally valid UTF-8.

  Returns:
    A `Tensor` of type `int32`, containing the length of the input string in
    the same element of the input tensor.
  """
  return gen_string_ops.string_length(input, unit=unit, name=name)


@tf_export("strings.length", v1=[])
@dispatch.add_dispatch_support
def string_length_v2(input, unit="BYTE", name=None):
  return gen_string_ops.string_length(input, unit=unit, name=name)


string_length_v2.__doc__ = gen_string_ops.string_length.__doc__


@tf_export(v1=["substr"])
@deprecation.deprecated(None, "Use `tf.strings.substr` instead of `tf.substr`.")
def substr_deprecated(input, pos, len, name=None, unit="BYTE"):
  return substr(input, pos, len, name=name, unit=unit)

substr_deprecated.__doc__ = gen_string_ops.substr.__doc__


@tf_export(v1=["strings.substr"])
@dispatch.add_dispatch_support
def substr(input, pos, len, name=None, unit="BYTE"):
  return gen_string_ops.substr(input, pos, len, unit=unit, name=name)

substr.__doc__ = gen_string_ops.substr.__doc__


@tf_export("strings.substr", v1=[])
@dispatch.add_dispatch_support
def substr_v2(input, pos, len, unit="BYTE", name=None):
  return gen_string_ops.substr(input, pos, len, unit=unit, name=name)

substr_v2.__doc__ = gen_string_ops.substr.__doc__
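
# Illustrative sketch (assumes eager execution; output indicative): `pos` and
# `len` are counted in the requested `unit`, which defaults to "BYTE".
#
#   tf.strings.substr(["Hello TensorFlow"], pos=6, len=10).numpy()
#   # => array([b'TensorFlow'], dtype=object)

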
ops.NotDifferentiable("RegexReplace")
ops.NotDifferentiable("StringToHashBucket")
ops.NotDifferentiable("StringToHashBucketFast")
ops.NotDifferentiable("StringToHashBucketStrong")
ops.NotDifferentiable("ReduceJoin")
ops.NotDifferentiable("StringJoin")
ops.NotDifferentiable("StringSplit")
ops.NotDifferentiable("AsString")
ops.NotDifferentiable("EncodeBase64")
ops.NotDifferentiable("DecodeBase64")


@tf_export("strings.to_number", v1=[])
@dispatch.add_dispatch_support
def string_to_number(input, out_type=dtypes.float32, name=None):
  r"""Converts each string in the input Tensor to the specified numeric type.

  (Note that int32 overflow results in an error while float overflow
  results in a rounded value.)

  Args:
    input: A `Tensor` of type `string`.
    out_type: An optional `tf.DType` from: `tf.float32, tf.float64, tf.int32,
      tf.int64`. Defaults to `tf.float32`.
      The numeric type to interpret each string in `input` as.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `out_type`.
  """
  return gen_parsing_ops.string_to_number(input, out_type, name)


@tf_export(v1=["strings.to_number", "string_to_number"])
def string_to_number_v1(
    string_tensor=None,
    out_type=dtypes.float32,
    name=None,
    input=None):
  string_tensor = deprecation.deprecated_argument_lookup(
      "input", input, "string_tensor", string_tensor)
  return gen_parsing_ops.string_to_number(string_tensor, out_type, name)

string_to_number_v1.__doc__ = gen_parsing_ops.string_to_number.__doc__
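
# Illustrative sketch (outputs indicative): parse decimal strings, optionally
# into an integer dtype.
#
#   tf.strings.to_number(["1.5", "-3"]).numpy()
#   # => array([ 1.5, -3. ], dtype=float32)
#   tf.strings.to_number(["42"], out_type=tf.int32).numpy()
#   # => array([42], dtype=int32)

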
@tf_export("strings.to_hash_bucket", v1=[])
@dispatch.add_dispatch_support
def string_to_hash_bucket(input, num_buckets, name=None):
  # pylint: disable=line-too-long
  r"""Converts each string in the input Tensor to its hash modulo a number of buckets.

  The hash function is deterministic on the content of the string within the
  process.

  Note that the hash function may change from time to time. This functionality
  will be deprecated; `tf.strings.to_hash_bucket_fast()` or
  `tf.strings.to_hash_bucket_strong()` is recommended instead.

  Args:
    input: A `Tensor` of type `string`.
    num_buckets: An `int` that is `>= 1`. The number of buckets.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `int64`.
  """
  # pylint: enable=line-too-long
  return gen_string_ops.string_to_hash_bucket(input, num_buckets, name)


@tf_export(v1=["strings.to_hash_bucket", "string_to_hash_bucket"])
def string_to_hash_bucket_v1(
    string_tensor=None,
    num_buckets=None,
    name=None,
    input=None):
  string_tensor = deprecation.deprecated_argument_lookup(
      "input", input, "string_tensor", string_tensor)
  return gen_string_ops.string_to_hash_bucket(string_tensor, num_buckets, name)

string_to_hash_bucket_v1.__doc__ = gen_string_ops.string_to_hash_bucket.__doc__
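
# Illustrative sketch: each string maps to an int64 bucket id in
# [0, num_buckets). The assignment is deterministic within a process, but the
# hash function is not guaranteed stable across releases (see note above), so
# no concrete bucket values are shown.
#
#   ids = tf.strings.to_hash_bucket(["Hello", "TensorFlow"], num_buckets=5)
#   # ids: int64 Tensor of shape [2], values in [0, 5)

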
@tf_export("strings.join", v1=["strings.join", "string_join"])
@deprecation.deprecated_endpoints("string_join")
@dispatch.add_dispatch_support
def string_join(inputs, separator="", name=None):
  """Perform element-wise concatenation of a list of string tensors.

  Given a list of string tensors of the same shape, performs element-wise
  concatenation of the strings of the same index in all tensors.

  >>> tf.strings.join(['abc','def']).numpy()
  b'abcdef'
  >>> tf.strings.join([['abc','123'],
  ...                  ['def','456'],
  ...                  ['ghi','789']]).numpy()
  array([b'abcdefghi', b'123456789'], dtype=object)
  >>> tf.strings.join([['abc','123'],
  ...                  ['def','456']],
  ...                  separator=" ").numpy()
  array([b'abc def', b'123 456'], dtype=object)

  Args:
    inputs: A list of `tf.Tensor` objects of the same size and `tf.string`
      dtype.
    separator: A string added between each string being joined.
    name: A name for the operation (optional).

  Returns:
    A `tf.string` tensor.
  """
  return gen_string_ops.string_join(inputs, separator=separator, name=name)