# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Ops to manipulate lists of tensors."""
16
17# pylint: disable=g-bad-name
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import numpy as np
23
24from tensorflow.core.framework import full_type_pb2
25from tensorflow.python.framework import cpp_shape_inference_pb2
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import tensor_shape
29from tensorflow.python.framework import tensor_util
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import gen_list_ops
32from tensorflow.python.ops import handle_data_util
33# go/tf-wildcard-import
34# pylint: disable=wildcard-import
35from tensorflow.python.ops.gen_list_ops import *
36# pylint: enable=wildcard-import
37from tensorflow.python.util.lazy_loader import LazyLoader
38
# Lazy loading breaks the import cycle:
# list_ops -> control_flow_ops -> tensor_array_ops -> list_ops.
control_flow_ops = LazyLoader(
    "control_flow_ops", globals(),
    "tensorflow.python.ops.control_flow_ops")


ops.NotDifferentiable("TensorListConcatLists")
ops.NotDifferentiable("TensorListElementShape")
ops.NotDifferentiable("TensorListLength")
ops.NotDifferentiable("TensorListPushBackBatch")


def empty_tensor_list(element_shape,
                      element_dtype,
                      max_num_elements=None,
                      name=None):
  """Creates an empty TensorList with the given element shape and dtype.

  A `max_num_elements` of None is converted to -1, which means the list
  size is unbounded.
  """
  if max_num_elements is None:
    max_num_elements = -1

  return gen_list_ops.empty_tensor_list(
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      max_num_elements=max_num_elements,
      name=name)


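# A minimal usage sketch (illustrative only, assuming eager execution; these
# lines are not part of the module's API):
#
#   l = empty_tensor_list(element_shape=[], element_dtype=dtypes.float32)
#   l = gen_list_ops.tensor_list_push_back(l, 3.0)
#   tensor_list_stack(l, element_dtype=dtypes.float32)  # => [3.0]

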
def _set_handle_data(list_handle, element_shape, element_dtype):
  """Sets type information on `list_handle` for consistency with graphs."""
  # TODO(b/169968286): It would be better if we had a consistent story for
  # creating handle data from eager operations (shared with VarHandleOp).
  if isinstance(list_handle, ops.EagerTensor):
    if tensor_util.is_tf_type(element_shape):
      element_shape = tensor_shape.TensorShape(None)
    elif not isinstance(element_shape, tensor_shape.TensorShape):
      element_shape = tensor_shape.TensorShape(element_shape)
    handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
    handle_data.is_set = True
    # TODO(b/191472076): This duplicates type inference. Clean up.
    handle_data.shape_and_type.append(
        cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
            shape=element_shape.as_proto(),
            dtype=element_dtype.as_datatype_enum,
            type=full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_ARRAY)))
    list_handle._handle_data = handle_data  # pylint: disable=protected-access


def tensor_list_reserve(element_shape, num_elements, element_dtype, name=None):
  """Creates a TensorList with `num_elements` uninitialized elements."""
  result = gen_list_ops.tensor_list_reserve(
      element_shape=_build_element_shape(element_shape),
      num_elements=num_elements,
      element_dtype=element_dtype,
      name=name)
  # TODO(b/169968286): gen_ops needs to ensure the metadata is properly
  # populated for eager operations.
  _set_handle_data(result, element_shape, element_dtype)
  return result


def tensor_list_from_tensor(tensor, element_shape, name=None):
  """Creates a TensorList whose elements are the slices of `tensor` on axis 0."""
  tensor = ops.convert_to_tensor(tensor)
  result = gen_list_ops.tensor_list_from_tensor(
      tensor=tensor,
      element_shape=_build_element_shape(element_shape),
      name=name)
  _set_handle_data(result, tensor.shape, tensor.dtype)
  return result


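# Illustrative round-trip sketch (eager mode; `constant_op` is an assumed
# import from tensorflow.python.framework, not imported above):
#
#   t = constant_op.constant([[1., 2.], [3., 4.]])
#   l = tensor_list_from_tensor(t, element_shape=[2])
#   tensor_list_stack(l, element_dtype=dtypes.float32)  # round-trips to `t`

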
def tensor_list_get_item(input_handle, index, element_dtype, element_shape=None,
                         name=None):
  """Returns the element at `index` of the input TensorList."""
  return gen_list_ops.tensor_list_get_item(
      input_handle=input_handle,
      index=index,
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      name=name)


def tensor_list_pop_back(input_handle, element_dtype, name=None):
  """Removes and returns the last element of the input TensorList."""
  return gen_list_ops.tensor_list_pop_back(
      input_handle=input_handle,
      element_shape=-1,
      element_dtype=element_dtype,
      name=name)


def tensor_list_gather(input_handle,
                       indices,
                       element_dtype,
                       element_shape=None,
                       name=None):
  """Stacks the list elements at the given `indices` into a single tensor."""
  return gen_list_ops.tensor_list_gather(
      input_handle=input_handle,
      indices=indices,
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      name=name)


def tensor_list_scatter(tensor,
                        indices,
                        element_shape=None,
                        input_handle=None,
                        name=None):
  """Returns a TensorList created or updated by scattering `tensor`."""
  tensor = ops.convert_to_tensor(tensor)
  if input_handle is not None:
    output_handle = gen_list_ops.tensor_list_scatter_into_existing_list(
        input_handle=input_handle, tensor=tensor, indices=indices, name=name)
    handle_data_util.copy_handle_data(input_handle, output_handle)
    return output_handle
  else:
    output_handle = gen_list_ops.tensor_list_scatter_v2(
        tensor=tensor,
        indices=indices,
        element_shape=_build_element_shape(element_shape),
        num_elements=-1,
        name=name)
    _set_handle_data(output_handle, element_shape, tensor.dtype)
    return output_handle


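# Illustrative sketch: scattering rows to arbitrary list indices (eager mode;
# the exact sizing behavior of the resulting list is an assumption here):
#
#   t = constant_op.constant([[1.], [2.]])
#   l = tensor_list_scatter(t, indices=[3, 0], element_shape=[1])
#   # element 0 is [2.], element 3 is [1.]; intermediate slots hold no value.

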
def tensor_list_stack(input_handle,
                      element_dtype,
                      num_elements=-1,
                      element_shape=None,
                      name=None):
  """Stacks all list elements into a single tensor along a new axis 0."""
  return gen_list_ops.tensor_list_stack(
      input_handle=input_handle,
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      num_elements=num_elements,
      name=name)


def tensor_list_concat(input_handle, element_dtype, element_shape=None,
                       name=None):
  """Concatenates the list elements along their first dimension."""
  # Ignore the lengths output of TensorListConcat. It is only used during
  # gradient computation.
  return gen_list_ops.tensor_list_concat_v2(
      input_handle=input_handle,
      element_dtype=element_dtype,
      element_shape=_build_element_shape(element_shape),
      leading_dims=ops.convert_to_tensor([], dtype=dtypes.int64),
      name=name)[0]


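# Illustrative sketch contrasting stack and concat (eager mode; `constant_op`
# is an assumed import as above):
#
#   l = tensor_list_from_tensor(
#       constant_op.constant([[1., 2.], [3., 4.]]), element_shape=[2])
#   tensor_list_stack(l, element_dtype=dtypes.float32)   # shape [2, 2]
#   tensor_list_concat(l, element_dtype=dtypes.float32)  # shape [4]

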
def tensor_list_split(tensor, element_shape, lengths, name=None):
  """Splits `tensor` along axis 0 into a TensorList of `lengths`-sized pieces."""
  return gen_list_ops.tensor_list_split(
      tensor=tensor,
      element_shape=_build_element_shape(element_shape),
      lengths=lengths,
      name=name)


def tensor_list_set_item(input_handle,
                         index,
                         item,
                         resize_if_index_out_of_bounds=False,
                         name=None):
  """Sets `item` at `index` in the input list."""
  if resize_if_index_out_of_bounds:
    input_list_size = gen_list_ops.tensor_list_length(input_handle)
    # TODO(srbs): This could cause some slowdown. Consider fusing resize
    # functionality in the SetItem op.
    input_handle = control_flow_ops.cond(
        index >= input_list_size,
        lambda: gen_list_ops.tensor_list_resize(  # pylint: disable=g-long-lambda
            input_handle, index + 1),
        lambda: input_handle)
  output_handle = gen_list_ops.tensor_list_set_item(
      input_handle=input_handle, index=index, item=item, name=name)
  handle_data_util.copy_handle_data(input_handle, output_handle)
  return output_handle


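# Illustrative sketch of the resize path (eager mode): writing past the end
# grows the list to `index + 1` elements when requested.
#
#   l = tensor_list_reserve(
#       element_shape=[], num_elements=2, element_dtype=dtypes.float32)
#   l = tensor_list_set_item(l, 5, 7.0, resize_if_index_out_of_bounds=True)
#   gen_list_ops.tensor_list_length(l)  # => 6

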
@ops.RegisterGradient("TensorListPushBack")
def _PushBackGrad(op, dresult):
  """Gradient for TensorListPushBack: pop the pushed element's gradient."""
  return gen_list_ops.tensor_list_pop_back(
      dresult,
      element_shape=array_ops.shape(op.inputs[1]),
      element_dtype=op.get_attr("element_dtype"))


@ops.RegisterGradient("TensorListPopBack")
def _PopBackGrad(op, dlist, delement):
  """Gradient for TensorListPopBack: push the element gradient back on."""
  if dlist is None:
    dlist = empty_tensor_list(
        element_dtype=delement.dtype,
        element_shape=gen_list_ops.tensor_list_element_shape(
            op.outputs[0], shape_type=dtypes.int32))
  if delement is None:
    delement = array_ops.zeros_like(op.outputs[1])
  return gen_list_ops.tensor_list_push_back(dlist, delement), None


@ops.RegisterGradient("TensorListStack")
def _TensorListStackGrad(unused_op, dtensor):
  """Gradient for TensorListStack: unstack the gradient back into a list."""
  return tensor_list_from_tensor(dtensor, element_shape=dtensor.shape[1:]), None


@ops.RegisterGradient("TensorListConcat")
@ops.RegisterGradient("TensorListConcatV2")
def _TensorListConcatGrad(op, dtensor, unused_dlengths):
  """Gradient function for TensorListConcat."""
  dlist = tensor_list_split(
      dtensor,
      element_shape=gen_list_ops.tensor_list_element_shape(
          op.inputs[0], shape_type=dtypes.int32),
      lengths=op.outputs[1])
  if op.type == "TensorListConcatV2":
    # TensorListConcatV2 takes two extra inputs (element_shape, leading_dims).
    return dlist, None, None
  else:
    return dlist


@ops.RegisterGradient("TensorListSplit")
def _TensorListSplitGrad(op, dlist):
  """Gradient for TensorListSplit: concat the list gradient back together."""
  tensor, _, lengths = op.inputs
  element_shape = array_ops.slice(array_ops.shape(tensor), [1], [-1])
  element_shape = array_ops.concat([[-1], element_shape], axis=0)
  return gen_list_ops.tensor_list_concat_v2(
      dlist,
      element_shape=element_shape,
      leading_dims=lengths,
      element_dtype=op.inputs[0].dtype)[0], None, None


@ops.RegisterGradient("TensorListFromTensor")
def _TensorListFromTensorGrad(op, dlist):
  """Gradient for TensorListFromTensor."""
  t = op.inputs[0]
  if t.shape.dims and t.shape.dims[0].value is not None:
    num_elements = t.shape.dims[0].value
  else:
    num_elements = None
  if dlist is None:
    dlist = empty_tensor_list(
        element_dtype=t.dtype,
        element_shape=gen_list_ops.tensor_list_element_shape(
            op.outputs[0], shape_type=dtypes.int32))
  tensor_grad = gen_list_ops.tensor_list_stack(
      dlist,
      element_shape=array_ops.slice(array_ops.shape(t), [1], [-1]),
      element_dtype=t.dtype,
      num_elements=num_elements)
  shape_grad = None
  return tensor_grad, shape_grad


@ops.RegisterGradient("TensorListGetItem")
def _TensorListGetItemGrad(op, ditem):
  """Gradient for TensorListGetItem."""
  list_size = gen_list_ops.tensor_list_length(op.inputs[0])
  list_grad = gen_list_ops.tensor_list_set_item(
      gen_list_ops.tensor_list_reserve(
          gen_list_ops.tensor_list_element_shape(op.inputs[0],
                                                 shape_type=dtypes.int32),
          list_size, element_dtype=ditem.dtype),
      index=op.inputs[1],
      item=ditem)
  index_grad = None
  element_shape_grad = None
  return list_grad, index_grad, element_shape_grad


@ops.RegisterGradient("TensorListSetItem")
def _TensorListSetItemGrad(op, dlist):
  """Gradient function for TensorListSetItem."""
  _, index, item = op.inputs
  list_grad = gen_list_ops.tensor_list_set_item(
      dlist, index=index, item=array_ops.zeros_like(item))
  index_grad = None
  element_grad = tensor_list_get_item(
      dlist,
      index,
      element_shape=array_ops.shape(item),
      element_dtype=item.dtype)
  return list_grad, index_grad, element_grad


@ops.RegisterGradient("TensorListResize")
def _TensorListResizeGrad(op, dlist):
  """Gradient for TensorListResize: resize back to the input's length."""
  input_list, _ = op.inputs
  input_list_size = gen_list_ops.tensor_list_length(input_list)
  return gen_list_ops.tensor_list_resize(dlist, input_list_size), None


@ops.RegisterGradient("TensorListGather")
def _TensorListGatherGrad(op, dtensor):
  """Gradient function for TensorListGather."""
  input_list, indices, _ = op.inputs
  element_shape = gen_list_ops.tensor_list_element_shape(
      input_list, shape_type=dtypes.int32)
  num_elements = gen_list_ops.tensor_list_length(input_list)
  dlist = tensor_list_reserve(element_shape, num_elements, dtensor.dtype)
  dlist = tensor_list_scatter(
      tensor=dtensor, indices=indices, input_handle=dlist)
  return dlist, None, None


@ops.RegisterGradient("TensorListScatter")
@ops.RegisterGradient("TensorListScatterV2")
def _TensorListScatterGrad(op, dlist):
  """Gradient function for TensorListScatter."""
  tensor = op.inputs[0]
  indices = op.inputs[1]
  dtensor = gen_list_ops.tensor_list_gather(
      dlist,
      indices,
      element_shape=array_ops.slice(array_ops.shape(tensor), [1], [-1]),
      element_dtype=tensor.dtype)
  if op.type == "TensorListScatterV2":
    # TensorListScatterV2 takes an extra num_elements input.
    return dtensor, None, None, None
  else:
    return dtensor, None, None


@ops.RegisterGradient("TensorListScatterIntoExistingList")
def _TensorListScatterIntoExistingListGrad(op, dlist):
  """Gradient function for TensorListScatterIntoExistingList."""
  _, tensor, indices = op.inputs
  dtensor = gen_list_ops.tensor_list_gather(
      dlist,
      indices,
      element_shape=array_ops.slice(array_ops.shape(tensor), [1], [-1]),
      element_dtype=tensor.dtype)
  zeros = array_ops.zeros_like(tensor)
  # element_shape is unused by tensor_list_scatter when input_handle is set.
  dlist = tensor_list_scatter(zeros, indices, input_handle=dlist)
  return dlist, dtensor, None


def _build_element_shape(shape):
  """Converts shape to a format understood by list_ops for element_shape.

  If `shape` is already a `Tensor` it is returned as-is. We do not perform a
  type check here.

  If shape is None or a TensorShape with unknown rank, -1 is returned.

  If shape denotes a scalar element (an empty list), an empty int32 tensor is
  returned. Note that we do not directly return the empty list, since
  ops.convert_to_tensor would convert it to a float32 tensor, which is not a
  valid type for element_shape.

  If shape is a sequence of dims, None's in the list are replaced with -1. We
  do not check the dtype of the other dims.

  Args:
    shape: Could be None, Tensor, TensorShape or a list of dims (each dim could
      be a None, scalar or Tensor).

  Returns:
    A None-free shape that can be converted to a tensor.
  """
  if isinstance(shape, ops.Tensor):
    return shape
  if isinstance(shape, tensor_shape.TensorShape):
    # `TensorShape.as_list` requires rank to be known.
    shape = shape.as_list() if shape else None
  # Shape is unknown.
  if shape is None:
    return -1
  # Shape is a numpy array or a scalar.
  if isinstance(shape, (np.ndarray, np.generic)) or not shape:
    return ops.convert_to_tensor(shape, dtype=dtypes.int32)
  # Shape is a sequence of dimensions. Convert None dims to -1.
  def convert(val):
    if val is None:
      return -1
    if isinstance(val, ops.Tensor):
      return val
    if isinstance(val, tensor_shape.Dimension):
      return val.value if val.value is not None else -1
    return val

  return [convert(d) for d in shape]
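
# Expected `_build_element_shape` results (derived from the branches above):
#
#   _build_element_shape(None)                            # -1 (unknown rank)
#   _build_element_shape(tensor_shape.TensorShape(None))  # -1 (unknown rank)
#   _build_element_shape([])                              # empty int32 tensor
#   _build_element_shape([2, None, 3])                    # [2, -1, 3]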