• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operators specific to slicing operations."""
16
17import collections
18
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import tensor_util
21from tensorflow.python.ops import gen_array_ops
22from tensorflow.python.ops import gen_string_ops
23from tensorflow.python.ops import list_ops
24from tensorflow.python.ops import tensor_array_ops
25
26
27# TODO(mdan): Support extended slices.
28
29
class GetItemOpts(collections.namedtuple('GetItemOpts', ('element_dtype',))):
  """Options for `get_item`.

  Attributes:
    element_dtype: The element dtype to use when reading from a Tensor list
      (a variant-dtype tensor). May be None, in which case reading from a
      Tensor list is rejected.
  """
  # An empty __slots__ keeps instances as plain immutable tuples; without it
  # every instance of this subclass would also carry a per-instance __dict__.
  __slots__ = ()
32
33
def get_item(target, i, opts):
  """The slice read operator (i.e. __getitem__).

  Dispatches on the type of `target`:
    * TensorArray: staged read via `TensorArray.read`.
    * variant-dtype Tensor (a Tensor list): staged read via
      `list_ops.tensor_list_get_item`, which requires `opts.element_dtype`.
    * scalar string Tensor: staged length-1 substring via
      `gen_string_ops.substr`.
    * any other Tensor: regular tensor indexing, `target[i]`.
    * anything else: plain Python indexing, `target[i]`.

  Unsupported target types are not rejected here; they fall through to Python
  indexing, so any error for them comes from `target[i]` itself.

  Note: it is unspecified whether target will be mutated or not. The Python
  fallback uses plain indexing and does not itself modify the target.

  Args:
    target: An entity that supports getitem semantics.
    i: Index to read from.
    opts: A GetItemOpts object.

  Returns:
    The read element.

  Raises:
    ValueError: if target is a Tensor list (variant dtype) and
      opts.element_dtype is None.
    AssertionError: if opts is not a GetItemOpts (only when Python
      assertions are enabled).
  """
  # NOTE(review): `assert` is stripped under `python -O`; it is kept as-is to
  # preserve the existing AssertionError behaviour.
  assert isinstance(opts, GetItemOpts)

  if isinstance(target, tensor_array_ops.TensorArray):
    return _tf_tensorarray_get_item(target, i)
  elif tensor_util.is_tf_type(target):
    if target.dtype == dtypes.variant:
      # Variant-dtype tensors are treated as Tensor lists.
      return _tf_tensor_list_get_item(target, i, opts)
    elif target.dtype == dtypes.string and target.shape.ndims == 0:
      # Only scalar strings are indexed by position within the string; string
      # tensors of other (or unknown) rank use regular tensor indexing.
      return _tf_tensor_string_get_item(target, i)
    else:
      return _tf_tensor_get_item(target, i)
  else:
    return _py_get_item(target, i)
64
65
def _tf_tensorarray_get_item(target, i):
  """Stages a read of element `i` from a TensorArray via `target.read`."""
  element = target.read(i)
  return element
69
70
def _tf_tensor_list_get_item(target, i, opts):
  """Stages a read of element `i` from a Tensor list (variant-dtype tensor).

  The element dtype must be supplied through `opts.element_dtype`; when it is
  None the read is rejected with a ValueError.
  """
  element_dtype = opts.element_dtype
  if element_dtype is not None:
    return list_ops.tensor_list_get_item(
        target, i, element_dtype=element_dtype)
  raise ValueError('cannot retrieve from a list without knowing its '
                   'element type; use set_element_type to annotate it')
78
79
def _tf_tensor_get_item(target, i):
  """Stages `target[i]` on a regular (non-list, non-scalar-string) Tensor."""
  sliced = target[i]
  return sliced
83
84
def _tf_tensor_string_get_item(target, i):
  """Stages a read of position `i` of a scalar string Tensor.

  The result is the length-1 substring starting at `i`. `substr` is called
  with its default unit, which per the TensorFlow documentation is "BYTE", so
  this yields one byte rather than one Unicode character.
  """
  return gen_string_ops.substr(target, i, 1)
89
90
def _py_get_item(target, i):
  """Overload of get_item that executes a plain Python read, `target[i]`."""
  return target[i]
94
95
def set_item(target, i, x):
  """The slice write operator (i.e. __setitem__).

  Dispatches on the type of `target`:
    * TensorArray: staged write via `TensorArray.write`.
    * variant-dtype Tensor (a Tensor list): staged update via
      `list_ops.tensor_list_set_item`.
    * any other Tensor: staged one-element `tensor_scatter_update` at `i`.
    * anything else: in-place Python assignment, `target[i] = x`.

  Unsupported target types are not rejected here; they fall through to Python
  item assignment, so any error for them (e.g. TypeError for a tuple) comes
  from `target[i] = x` itself.

  Note: it is unspecified whether target will be mutated or not. In general,
  if target is mutable (like Python lists), it will be mutated. The staged
  TensorFlow variants return the updated value, so callers should always use
  the return value.

  Args:
    target: An entity that supports setitem semantics.
    i: Index to modify.
    x: The new element value.

  Returns:
    Same as target, after the update was performed.
  """
  if isinstance(target, tensor_array_ops.TensorArray):
    return _tf_tensorarray_set_item(target, i, x)
  elif tensor_util.is_tf_type(target):
    if target.dtype == dtypes.variant:
      # Variant-dtype tensors are treated as Tensor lists.
      return _tf_tensor_list_set_item(target, i, x)
    else:
      return _tf_tensor_set_item(target, i, x)
  else:
    return _py_set_item(target, i, x)
122
123
def _tf_tensorarray_set_item(target, i, x):
  """Stages a TensorArray write of `x` at `i`; returns what `write` returns."""
  updated = target.write(i, x)
  return updated
127
128
def _tf_tensor_list_set_item(target, i, x):
  """Stages setting element `i` of a Tensor list (variant tensor) to `x`."""
  updated_list = list_ops.tensor_list_set_item(target, i, x)
  return updated_list
132
133
def _tf_tensor_set_item(target, i, x):
  """Stages an update of `target[i]` to `x` for a regular Tensor.

  Implemented as a one-element `tensor_scatter_update`: the single index
  `((i,),)` is paired with the single update `(x,)`. Assuming a scalar `i`,
  this replaces the slice at position `i` along the first axis.
  """
  indices = ((i,),)
  updates = (x,)
  return gen_array_ops.tensor_scatter_update(target, indices, updates)
137
138
def _py_set_item(target, i, x):
  """Assigns `target[i] = x` in place and returns the same `target` object."""
  container = target
  container[i] = x
  return container
143