• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Special functions that only make sense for AutoGraph.

These functions are meant to ensure feature parity between Python and AutoGraph,
so that the exact same code works in both modes. In general, AutoGraph will
replace these calls.
"""
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26from tensorflow.python.autograph.operators import data_structures
27from tensorflow.python.framework import constant_op
28from tensorflow.python.framework import tensor_util
29
30
def _validate_list_constructor(elements, element_dtype, element_shape):
  """Checks that tensor_list received enough information to build a list.

  Validation passes immediately when both element_dtype and element_shape are
  provided. Otherwise, elements must be a TF tensor or a non-empty list/tuple.

  Args:
    elements: the initial contents passed to tensor_list.
    element_dtype: the (possibly None) element dtype passed to tensor_list.
    element_shape: the (possibly None) element shape passed to tensor_list.

  Raises:
    ValueError: if elements is an empty list/tuple while element_dtype or
      element_shape is missing, or if elements is of any other type.
  """
  fully_specified = element_dtype is not None and element_shape is not None
  # `or` short-circuits, so the tensor check only runs when the dtype/shape
  # pair is incomplete.
  if fully_specified or tensor_util.is_tf_type(elements):
    return

  if not isinstance(elements, (list, tuple)):
    raise ValueError(
        'unknown type for elements: {}; only Tensor, list and tuple are'
        ' allowed'.format(type(elements)))

  # An empty Python sequence gives no element to infer dtype/shape from.
  if not elements:
    raise ValueError(
        'element_dtype and element_shape are required when elements are'
        ' empty')
48
49
def match_staging_level(value, like_value):
  """Stages `value` at the same level as `like_value`.

  Args:
    value: the value to (possibly) convert.
    like_value: the reference value whose staging level is matched.

  Returns:
    `constant_op.constant(value)` when `like_value` is a TF type, otherwise
    `value` unchanged.
  """
  if not tensor_util.is_tf_type(like_value):
    # Reference is a plain Python value: leave `value` unstaged as well.
    return value
  return constant_op.constant(value)
55
56
def tensor_list(elements,
                element_dtype=None,
                element_shape=None,
                use_tensor_array=False):
  """Creates a tensor list and populates it with the given elements.

  Offers a single entry point for both tensor lists and tensor arrays, with
  optional initial contents.

  Note: this is a simplified wrapper. When more control is needed, use the
  underlying implementation directly.

  Args:
    elements: Iterable[tf.Tensor, ...], the initial contents of the list.
    element_dtype: Optional[tf.DType], the data type of the list elements;
        required if the list is empty.
    element_shape: Optional[tf.TensorShape], the shape of the list elements;
        required if the list is empty.
    use_tensor_array: bool, whether to use the more compatible but more
        restrictive tf.TensorArray implementation.

  Returns:
    Union[tf.Tensor, tf.TensorArray], the new list.

  Raises:
    ValueError: for invalid arguments.
  """
  _validate_list_constructor(elements, element_dtype, element_shape)
  constructor = (
      data_structures.tf_tensor_array_new
      if use_tensor_array else data_structures.tf_tensor_list_new)
  return constructor(elements, element_dtype, element_shape)
90
91
def stack(list_or_tensor, element_dtype=None, strict=True):
  """Stacks the input, if it admits the notion of stacking.

  For example, a list of tensors can be stacked into a larger tensor. This
  function is similar to tf.stack, but it accepts non-lists and lists of
  non-tensors as arguments. In the latter case (with strict=False), the
  function does nothing.

  Args:
    list_or_tensor: Any
    element_dtype: tf.DType, optional dtype for the elements in the list.
        Required if the input is stackable, and the list is untyped.
    strict: bool, if True an error is raised if the input is not stackable.
        Otherwise the function is a no-op.

  Returns:
    Any, if the input is stackable, the result will be a tf.Tensor. Otherwise,
    if strict=False, the result will be list_or_tensor.

  Raises:
    ValueError: if strict=True and the input is not stackable.
  """
  # `original_call` is the fallback handed to data_structures.list_stack via
  # ListStackOpts.
  if strict:
    def original_call(x):
      # Wrap x in a 1-tuple: with a bare `% x`, a tuple input would be
      # unpacked as format arguments, raising TypeError ("not all arguments
      # converted") instead of the documented ValueError.
      raise ValueError('%s must be stackable when strict=True' % (x,))
  else:
    def original_call(x):
      return x
  return data_structures.list_stack(
      list_or_tensor,
      data_structures.ListStackOpts(
          element_dtype=element_dtype, original_call=original_call))
123