• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Inplace operations.
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import gen_array_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.util import deprecation
28
29
def _inplace_helper(x, i, v, op):
  """Dispatches an inplace row op over (x, i, v) based on the kind of i.

  Within this module, op is one of gen_array_ops.inplace_update,
  gen_array_ops.inplace_add, or gen_array_ops.inplace_sub.

  x is converted to a tensor and v is converted to a tensor of x's dtype
  before dispatching:

  If i is None, x and v must be the same shape. Both are viewed as a single
  row of shape [1, -1], op is applied to row 0, and the result is reshaped
  back to the shape of x. Computes
    x op v;
  If i is a scalar, x has a rank 1 higher than v's. i is turned into a
  one-element vector and v gains a leading dimension of size 1. Computes
    x[i, :] op v;
  Otherwise, x and v must have the same rank. Computes
    x[i, :] op v;

  Args:
    x: A Tensor, or a value convertible to one.
    i: None, a scalar or a vector. Cast to int32 when it is not None.
    v: A Tensor, or a value convertible to one of x's dtype.
    op: A callable invoked as op(x, i, v), e.g. gen_array_ops.inplace_update.

  Returns:
    The tensor produced by op (reshaped to the shape of x when i is None).

  """
  x = ops.convert_to_tensor(x)
  v = ops.convert_to_tensor(v, x.dtype)

  if i is not None:
    index = math_ops.cast(i, dtypes.int32)
    if index.get_shape().ndims != 0:
      # Vector index, or an index whose static rank is not known to be 0:
      # hand the operands to op unchanged.
      return op(x, index, v)
    # Single 0-dim update: promote the scalar index to a length-1 vector and
    # give v a matching leading dimension.
    row_index = array_ops.reshape(index, [1])
    row_value = array_ops.expand_dims(v, 0)
    return op(x, row_index, row_value)

  # Full tensor: treat x and v as one flat row and operate on row 0.
  flat_x = array_ops.reshape(x, [1, -1])
  flat_v = array_ops.reshape(v, [1, -1])
  flat_result = op(flat_x, [0], flat_v)
  return array_ops.reshape(flat_result, array_ops.shape(x))
65
66
@deprecation.deprecated(
    None,
    ('Prefer tf.tensor_scatter_nd_update, which offers the same functionality '
     'with well-defined read-write semantics.'))
def alias_inplace_update(x, i, v):
  """Overwrites x at index i with value v, in place. Aliases x.

  The work is delegated to _inplace_helper with
  gen_array_ops.inplace_update as the op.

  If i is None, x and v must be the same shape. Computes
    x = v;
  If i is a scalar, x has a rank 1 higher than v's. Computes
    x[i, :] = v;
  Otherwise, x and v must have the same rank. Computes
    x[i, :] = v;

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns x.

  """
  update_op = gen_array_ops.inplace_update
  return _inplace_helper(x, i, v, op=update_op)
91
92
@deprecation.deprecated(
    None,
    ('Prefer tf.tensor_scatter_nd_add, which offers the same functionality '
     'with well-defined read-write semantics.'))
def alias_inplace_add(x, i, v):
  """Adds value v into x at index i, in place. Aliases x.

  The work is delegated to _inplace_helper with
  gen_array_ops.inplace_add as the op.

  If i is None, x and v must be the same shape. Computes
    x += v;
  If i is a scalar, x has a rank 1 higher than v's. Computes
    x[i, :] += v;
  Otherwise, x and v must have the same rank. Computes
    x[i, :] += v;

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns x.

  """
  add_op = gen_array_ops.inplace_add
  return _inplace_helper(x, i, v, op=add_op)
117
118
@deprecation.deprecated(
    None,
    ('Prefer tf.tensor_scatter_nd_sub, which offers the same functionality '
     'with well-defined read-write semantics.'))
def alias_inplace_sub(x, i, v):
  """Subtracts value v from x at index i, in place. Aliases x.

  The work is delegated to _inplace_helper with
  gen_array_ops.inplace_sub as the op.

  If i is None, x and v must be the same shape. Computes
    x -= v;
  If i is a scalar, x has a rank 1 higher than v's. Computes
    x[i, :] -= v;
  Otherwise, x and v must have the same rank. Computes
    x[i, :] -= v;

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns x.

  """
  sub_op = gen_array_ops.inplace_sub
  return _inplace_helper(x, i, v, op=sub_op)
143
144
def empty_like(x, init=None):
  """Allocates a new tensor with the same shape and dtype as x.

  x is first converted to a tensor; its dynamic shape and its dtype are then
  forwarded to gen_array_ops.empty together with init.

  Args:
    x: A Tensor, or a value convertible to one.
    init: Initialize the returned tensor with the default value of
      x.dtype(), if True. Otherwise, do not initialize. Defaults to
      None.

  Returns:
    A tensor y, whose dtype and shape are the same as those of x.
    y is guaranteed not to be an alias of x. Upon return, y may contain
    arbitrary data.

  """
  like = ops.convert_to_tensor(x)
  like_shape = array_ops.shape(like)
  return gen_array_ops.empty(like_shape, like.dtype, init=init)
162
163
@deprecation.deprecated(
    None,
    ('Prefer tf.tensor_scatter_nd_update, which offers the same functionality '
     'with well-defined read-write semantics.'))
def inplace_update(x, i, v):
  """Applies an update at index i with value v to a copy of x.

  Note that this function is not actually inplace - it allocates
  a copy of x (via gen_array_ops.deep_copy) and updates that copy.  The
  utility is not avoiding memory copies but rather specifying a sparse
  update.

  If i is None, x and v must be the same shape. Computes
    y = x; y = v;
  If i is a scalar, x has a rank 1 higher than v's. Computes
    y = x; y[i, :] = v;
  Otherwise, x and v must have the same rank. Computes
    y = x; y[i, :] = v;

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns y, which is guaranteed not to be an alias of x.

  """
  # Go through the shared helper instead of the public alias_inplace_update
  # wrapper. That wrapper is itself marked deprecated, so calling it here
  # would log an extra deprecation warning for a function the caller never
  # invoked. The op applied to the copy is the same one the wrapper uses.
  return _inplace_helper(
      gen_array_ops.deep_copy(x), i, v, gen_array_ops.inplace_update)
192
193
@deprecation.deprecated(
    None,
    ('Prefer tf.tensor_scatter_nd_add, which offers the same functionality '
     'with well-defined read-write semantics.'))
def inplace_add(x, i, v):
  """Applies an add at index i with value v to a copy of x.

  Note that this function is not actually inplace - it allocates
  a copy of x (via gen_array_ops.deep_copy) and updates that copy.  The
  utility is not avoiding memory copies but rather specifying a sparse
  update.

  If i is None, x and v must be the same shape. Computes
    y = x; y += v;
  If i is a scalar, x has a rank 1 higher than v's. Computes
    y = x; y[i, :] += v;
  Otherwise, x and v must have the same rank. Computes
    y = x; y[i, :] += v;

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns y, which is guaranteed not to be an alias of x.

  """
  # Go through the shared helper instead of the public alias_inplace_add
  # wrapper. That wrapper is itself marked deprecated, so calling it here
  # would log an extra deprecation warning for a function the caller never
  # invoked. The op applied to the copy is the same one the wrapper uses.
  return _inplace_helper(
      gen_array_ops.deep_copy(x), i, v, gen_array_ops.inplace_add)
222
223
@deprecation.deprecated(
    None,
    ('Prefer tf.tensor_scatter_nd_sub, which offers the same functionality '
     'with well-defined read-write semantics.'))
def inplace_sub(x, i, v):
  """Applies a sub at index i with value v to a copy of x.

  Note that this function is not actually inplace - it allocates
  a copy of x (via gen_array_ops.deep_copy) and updates that copy.  The
  utility is not avoiding memory copies but rather specifying a sparse
  update.

  If i is None, x and v must be the same shape. Computes
    y = x; y -= v;
  If i is a scalar, x has a rank 1 higher than v's. Computes
    y = x; y[i, :] -= v;
  Otherwise, x and v must have the same rank. Computes
    y = x; y[i, :] -= v;

  Args:
    x: A Tensor.
    i: None, a scalar or a vector.
    v: A Tensor.

  Returns:
    Returns y, which is guaranteed not to be an alias of x.

  """
  # Go through the shared helper instead of the public alias_inplace_sub
  # wrapper. That wrapper is itself marked deprecated, so calling it here
  # would log an extra deprecation warning for a function the caller never
  # invoked. The op applied to the copy is the same one the wrapper uses.
  return _inplace_helper(
      gen_array_ops.deep_copy(x), i, v, gen_array_ops.inplace_sub)
252
# Module-level alias: exposes gen_array_ops.empty as `empty` from this module.
empty = gen_array_ops.empty
254