# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops to manipulate lists of tensors."""

# pylint: disable=g-bad-name
import numpy as np

from tensorflow.core.framework import full_type_pb2
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import handle_data_util
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_list_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util.lazy_loader import LazyLoader

# LazyLoader is used to break the circular import chain:
# list_ops -> control_flow_ops -> tensor_array_ops -> list_ops.
control_flow_ops = LazyLoader(
    "control_flow_ops", globals(),
    "tensorflow.python.ops.control_flow_ops")


ops.NotDifferentiable("TensorListConcatLists")
ops.NotDifferentiable("TensorListElementShape")
ops.NotDifferentiable("TensorListLength")
ops.NotDifferentiable("TensorListPushBackBatch")

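# Example (an illustrative sketch, not part of the original module): basic
# TensorList round trip, assuming eager execution. `tensor_list_push_back`
# comes from the wildcard import of gen_list_ops above.
#
#   l = empty_tensor_list(element_shape=[], element_dtype=dtypes.int32)
#   l = tensor_list_push_back(l, 1)
#   l = tensor_list_push_back(l, 2)
#   tensor_list_stack(l, element_dtype=dtypes.int32)  # => [1, 2]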

def empty_tensor_list(element_shape,
                      element_dtype,
                      max_num_elements=None,
                      name=None):
  """Creates an empty TensorList with the given element shape and dtype."""
  if max_num_elements is None:
    max_num_elements = -1

  return gen_list_ops.empty_tensor_list(
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      max_num_elements=max_num_elements,
      name=name)


def _set_handle_data(list_handle, element_shape, element_dtype):
  """Sets type information on `list_handle` for consistency with graphs."""
  # TODO(b/169968286): It would be better if we had a consistent story for
  # creating handle data from eager operations (shared with VarHandleOp).
  if isinstance(list_handle, ops.EagerTensor):
    if tensor_util.is_tf_type(element_shape):
      element_shape = tensor_shape.TensorShape(None)
    elif not isinstance(element_shape, tensor_shape.TensorShape):
      element_shape = tensor_shape.TensorShape(element_shape)
    handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
    handle_data.is_set = True
    # TODO(b/191472076): This duplicates type inference. Clean up.
    handle_data.shape_and_type.append(
        cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
            shape=element_shape.as_proto(),
            dtype=element_dtype.as_datatype_enum,
            type=full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_ARRAY)))
    list_handle._handle_data = handle_data  # pylint: disable=protected-access


def tensor_list_reserve(element_shape, num_elements, element_dtype, name=None):
  """Creates a TensorList of the given size with uninitialized elements."""
  result = gen_list_ops.tensor_list_reserve(
      element_shape=_build_element_shape(element_shape),
      num_elements=num_elements,
      element_dtype=element_dtype,
      name=name)
  # TODO(b/169968286): gen_ops needs to ensure the metadata is properly
  # populated for eager operations.
  _set_handle_data(result, element_shape, element_dtype)
  return result

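# Example (an illustrative sketch, not part of the original module): reserve a
# list of two scalar slots, fill them, and stack the result, assuming eager
# execution.
#
#   l = tensor_list_reserve(
#       element_shape=[], num_elements=2, element_dtype=dtypes.float32)
#   l = tensor_list_set_item(l, 0, 3.0)
#   l = tensor_list_set_item(l, 1, 4.0)
#   tensor_list_stack(l, element_dtype=dtypes.float32)  # => [3.0, 4.0]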

def tensor_list_from_tensor(tensor, element_shape, name=None):
  """Creates a TensorList whose elements are the dim-0 slices of `tensor`."""
  tensor = ops.convert_to_tensor(tensor)
  result = gen_list_ops.tensor_list_from_tensor(
      tensor=tensor,
      element_shape=_build_element_shape(element_shape),
      name=name)
  _set_handle_data(result, tensor.shape, tensor.dtype)
  return result

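# Example (an illustrative sketch, not part of the original module): build a
# list from the rows of a tensor, then read one element back, assuming eager
# execution.
#
#   l = tensor_list_from_tensor([[1, 2], [3, 4]], element_shape=[2])
#   tensor_list_get_item(l, 1, element_dtype=dtypes.int32)  # => [3, 4]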

def tensor_list_get_item(input_handle, index, element_dtype, element_shape=None,
                         name=None):
  """Returns the element at `index` in the input list."""
  return gen_list_ops.tensor_list_get_item(
      input_handle=input_handle,
      index=index,
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      name=name)


def tensor_list_pop_back(input_handle, element_dtype, name=None):
  """Returns the list with the last element removed, and that element."""
  return gen_list_ops.tensor_list_pop_back(
      input_handle=input_handle,
      element_shape=-1,
      element_dtype=element_dtype,
      name=name)


def tensor_list_gather(input_handle,
                       indices,
                       element_dtype,
                       element_shape=None,
                       name=None):
  """Gathers the list elements at `indices` into a stacked tensor."""
  return gen_list_ops.tensor_list_gather(
      input_handle=input_handle,
      indices=indices,
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      name=name)


def tensor_list_scatter(tensor,
                        indices,
                        element_shape=None,
                        input_handle=None,
                        name=None):
  """Returns a TensorList created or updated by scattering `tensor`."""
  tensor = ops.convert_to_tensor(tensor)
  if input_handle is not None:
    output_handle = gen_list_ops.tensor_list_scatter_into_existing_list(
        input_handle=input_handle, tensor=tensor, indices=indices, name=name)
    handle_data_util.copy_handle_data(input_handle, output_handle)
    return output_handle
  else:
    output_handle = gen_list_ops.tensor_list_scatter_v2(
        tensor=tensor,
        indices=indices,
        element_shape=_build_element_shape(element_shape),
        num_elements=-1,
        name=name)
    _set_handle_data(output_handle, element_shape, tensor.dtype)
    return output_handle

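# Example (an illustrative sketch, not part of the original module): scatter
# rows of a tensor into a fresh two-element list, assuming eager execution.
#
#   l = tensor_list_scatter([[1.0], [2.0]], indices=[1, 0], element_shape=[1])
#   tensor_list_stack(l, element_dtype=dtypes.float32)  # => [[2.0], [1.0]]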

def tensor_list_stack(input_handle,
                      element_dtype,
                      num_elements=-1,
                      element_shape=None,
                      name=None):
  """Stacks all list elements into a tensor with a new leading dimension."""
  return gen_list_ops.tensor_list_stack(
      input_handle=input_handle,
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      num_elements=num_elements,
      name=name)


def tensor_list_concat(input_handle, element_dtype, element_shape=None,
                       name=None):
  """Concatenates the list elements along their first dimension."""
  # Ignore the lengths output of TensorListConcat. It is only used during
  # gradient computation.
  return gen_list_ops.tensor_list_concat_v2(
      input_handle=input_handle,
      element_dtype=element_dtype,
      element_shape=_build_element_shape(element_shape),
      leading_dims=ops.convert_to_tensor([], dtype=dtypes.int64),
      name=name)[0]

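# Example (an illustrative sketch, not part of the original module)
# contrasting the two read-out ops: `tensor_list_stack` adds a new leading
# dimension, while `tensor_list_concat` joins elements along their existing
# first dimension.
#
#   l = tensor_list_from_tensor([[1, 2], [3, 4]], element_shape=[2])
#   tensor_list_stack(l, element_dtype=dtypes.int32)   # shape [2, 2]
#   tensor_list_concat(l, element_dtype=dtypes.int32)  # shape [4]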

def tensor_list_split(tensor, element_shape, lengths, name=None):
  """Splits `tensor` along dim 0 into a list of pieces of sizes `lengths`."""
  return gen_list_ops.tensor_list_split(
      tensor=tensor,
      element_shape=_build_element_shape(element_shape),
      lengths=lengths,
      name=name)


def tensor_list_set_item(input_handle,
                         index,
                         item,
                         resize_if_index_out_of_bounds=False,
                         name=None):
  """Sets `item` at `index` in the input list, resizing the list if needed."""
  if resize_if_index_out_of_bounds:
    input_list_size = gen_list_ops.tensor_list_length(input_handle)
    # TODO(srbs): This could cause some slowdown. Consider fusing resize
    # functionality in the SetItem op.
    input_handle = control_flow_ops.cond(
        index >= input_list_size,
        lambda: gen_list_ops.tensor_list_resize(  # pylint: disable=g-long-lambda
            input_handle, index + 1),
        lambda: input_handle)
  output_handle = gen_list_ops.tensor_list_set_item(
      input_handle=input_handle, index=index, item=item, name=name)
  handle_data_util.copy_handle_data(input_handle, output_handle)
  return output_handle

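# Example (an illustrative sketch, not part of the original module): writing
# past the end with `resize_if_index_out_of_bounds=True` grows the list first,
# assuming eager execution.
#
#   l = tensor_list_reserve(
#       element_shape=[], num_elements=1, element_dtype=dtypes.int32)
#   l = tensor_list_set_item(l, 3, 7, resize_if_index_out_of_bounds=True)
#   gen_list_ops.tensor_list_length(l)  # => 4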

@ops.RegisterGradient("TensorListPushBack")
def _PushBackGrad(op, dresult):
  """Gradient function for TensorListPushBack."""
  return gen_list_ops.tensor_list_pop_back(
      dresult,
      element_shape=array_ops.shape(op.inputs[1]),
      element_dtype=op.get_attr("element_dtype"))


@ops.RegisterGradient("TensorListPopBack")
def _PopBackGrad(op, dlist, delement):
  """Gradient function for TensorListPopBack."""
  if dlist is None:
    dlist = empty_tensor_list(
        element_dtype=delement.dtype,
        element_shape=gen_list_ops.tensor_list_element_shape(
            op.outputs[0], shape_type=dtypes.int32))
  if delement is None:
    delement = array_ops.zeros_like(op.outputs[1])
  return gen_list_ops.tensor_list_push_back(dlist, delement), None


@ops.RegisterGradient("TensorListStack")
def _TensorListStackGrad(unused_op, dtensor):
  """Gradient function for TensorListStack."""
  return tensor_list_from_tensor(dtensor, element_shape=dtensor.shape[1:]), None


@ops.RegisterGradient("TensorListConcat")
@ops.RegisterGradient("TensorListConcatV2")
def _TensorListConcatGrad(op, dtensor, unused_dlengths):
  """Gradient function for TensorListConcat."""
  dlist = tensor_list_split(
      dtensor,
      element_shape=gen_list_ops.tensor_list_element_shape(
          op.inputs[0], shape_type=dtypes.int32),
      lengths=op.outputs[1])
  if op.type == "TensorListConcatV2":
    return dlist, None, None
  else:
    return dlist


@ops.RegisterGradient("TensorListSplit")
def _TensorListSplitGrad(op, dlist):
  """Gradient function for TensorListSplit."""
  tensor, _, lengths = op.inputs
  element_shape = array_ops.slice(array_ops.shape(tensor), [1], [-1])
  element_shape = array_ops.concat([[-1], element_shape], axis=0)
  return gen_list_ops.tensor_list_concat_v2(
      dlist,
      element_shape=element_shape,
      leading_dims=lengths,
      element_dtype=op.inputs[0].dtype)[0], None, None


@ops.RegisterGradient("TensorListFromTensor")
def _TensorListFromTensorGrad(op, dlist):
  """Gradient for TensorListFromTensor."""
  t = op.inputs[0]
  if t.shape.dims and t.shape.dims[0].value is not None:
    num_elements = t.shape.dims[0].value
  else:
    num_elements = None
  if dlist is None:
    dlist = empty_tensor_list(
        element_dtype=t.dtype,
        element_shape=gen_list_ops.tensor_list_element_shape(
            op.outputs[0], shape_type=dtypes.int32))
  tensor_grad = gen_list_ops.tensor_list_stack(
      dlist,
      element_shape=array_ops.slice(array_ops.shape(t), [1], [-1]),
      element_dtype=t.dtype,
      num_elements=num_elements)
  shape_grad = None
  return tensor_grad, shape_grad


@ops.RegisterGradient("TensorListGetItem")
def _TensorListGetItemGrad(op, ditem):
  """Gradient for TensorListGetItem."""
  list_size = gen_list_ops.tensor_list_length(op.inputs[0])
  list_grad = gen_list_ops.tensor_list_set_item(
      gen_list_ops.tensor_list_reserve(
          gen_list_ops.tensor_list_element_shape(op.inputs[0],
                                                 shape_type=dtypes.int32),
          list_size, element_dtype=ditem.dtype),
      index=op.inputs[1],
      item=ditem)
  index_grad = None
  element_shape_grad = None
  return list_grad, index_grad, element_shape_grad


@ops.RegisterGradient("TensorListSetItem")
def _TensorListSetItemGrad(op, dlist):
  """Gradient function for TensorListSetItem."""
  _, index, item = op.inputs
  list_grad = gen_list_ops.tensor_list_set_item(
      dlist, index=index, item=array_ops.zeros_like(item))
  index_grad = None
  element_grad = tensor_list_get_item(
      dlist,
      index,
      element_shape=array_ops.shape(item),
      element_dtype=item.dtype)
  return list_grad, index_grad, element_grad


@ops.RegisterGradient("TensorListResize")
def _TensorListResizeGrad(op, dlist):
  """Gradient function for TensorListResize."""
  input_list, _ = op.inputs
  input_list_size = gen_list_ops.tensor_list_length(input_list)
  return gen_list_ops.tensor_list_resize(dlist, input_list_size), None


@ops.RegisterGradient("TensorListGather")
def _TensorListGatherGrad(op, dtensor):
  """Gradient function for TensorListGather."""
  input_list, indices, _ = op.inputs
  element_shape = gen_list_ops.tensor_list_element_shape(
      input_list, shape_type=dtypes.int32)
  num_elements = gen_list_ops.tensor_list_length(input_list)
  dlist = tensor_list_reserve(element_shape, num_elements, dtensor.dtype)
  dlist = tensor_list_scatter(
      tensor=dtensor, indices=indices, input_handle=dlist)
  return dlist, None, None


@ops.RegisterGradient("TensorListScatter")
@ops.RegisterGradient("TensorListScatterV2")
def _TensorListScatterGrad(op, dlist):
  """Gradient function for TensorListScatter."""
  tensor = op.inputs[0]
  indices = op.inputs[1]
  dtensor = gen_list_ops.tensor_list_gather(
      dlist,
      indices,
      element_shape=array_ops.slice(array_ops.shape(tensor), [1], [-1]),
      element_dtype=tensor.dtype)
  if op.type == "TensorListScatterV2":
    return dtensor, None, None, None
  else:
    return dtensor, None, None


@ops.RegisterGradient("TensorListScatterIntoExistingList")
def _TensorListScatterIntoExistingListGrad(op, dlist):
  """Gradient function for TensorListScatterIntoExistingList."""
  _, tensor, indices = op.inputs
  dtensor = gen_list_ops.tensor_list_gather(
      dlist,
      indices,
      element_shape=array_ops.slice(array_ops.shape(tensor), [1], [-1]),
      element_dtype=tensor.dtype)
  zeros = array_ops.zeros_like(tensor)
  # `element_shape` is ignored when scattering into an existing list, so it is
  # not passed here.
  dlist = tensor_list_scatter(zeros, indices, input_handle=dlist)
  return dlist, dtensor, None


def _build_element_shape(shape):
  """Converts shape to a format understood by list_ops for element_shape.

  If `shape` is already a `Tensor` it is returned as-is. We do not perform a
  type check here.

  If `shape` is None or a `TensorShape` with unknown rank, -1 is returned.

  If `shape` is a scalar, an int32 tensor wrapping an empty list is returned.
  Note we do not directly return the empty list, since `ops.convert_to_tensor`
  would convert it to a float32, which is not a valid type for `element_shape`.

  If `shape` is a sequence of dims, `None` dims in the list are replaced with
  -1. We do not check the dtype of the other dims.

  Args:
    shape: Could be None, Tensor, TensorShape or a list of dims (each dim could
      be a None, scalar or Tensor).

  Returns:
    A None-free shape that can be converted to a tensor.
  """
  if isinstance(shape, ops.Tensor):
    return shape
  if isinstance(shape, tensor_shape.TensorShape):
    # `TensorShape.as_list` requires rank to be known.
    shape = shape.as_list() if shape else None
  # Shape is unknown.
  if shape is None:
    return -1
  # Shape is a numpy array or a scalar.
  if isinstance(shape, (np.ndarray, np.generic)) or not shape:
    return ops.convert_to_tensor(shape, dtype=dtypes.int32)
  # Shape is a sequence of dimensions. Convert None dims to -1.
  def convert(val):
    if val is None:
      return -1
    if isinstance(val, ops.Tensor):
      return val
    if isinstance(val, tensor_shape.Dimension):
      return val.value if val.value is not None else -1
    return val

  return [convert(d) for d in shape]

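# Example (an illustrative sketch, not part of the original module) of how
# `_build_element_shape` normalizes the accepted shape formats:
#
#   _build_element_shape(None)                            # => -1
#   _build_element_shape(tensor_shape.TensorShape(None))  # => -1
#   _build_element_shape(tensor_shape.TensorShape([]))    # => int32 tensor []
#   _build_element_shape([None, 3])                       # => [-1, 3]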