# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops to manipulate lists of tensors."""

# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import handle_data_util
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_list_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util.lazy_loader import LazyLoader

# Lazy-load control_flow_ops to break the import cycle:
# list_ops -> control_flow_ops -> tensor_array_ops -> list_ops
control_flow_ops = LazyLoader(
    "control_flow_ops", globals(),
    "tensorflow.python.ops.control_flow_ops")


ops.NotDifferentiable("TensorListConcatLists")
ops.NotDifferentiable("TensorListElementShape")
ops.NotDifferentiable("TensorListLength")
ops.NotDifferentiable("TensorListPushBackBatch")


def empty_tensor_list(element_shape,
                      element_dtype,
                      max_num_elements=None,
                      name=None):
  if max_num_elements is None:
    max_num_elements = -1

  return gen_list_ops.empty_tensor_list(
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      max_num_elements=max_num_elements,
      name=name)
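
# A minimal usage sketch (illustrative only, not part of the module): create
# an unbounded list and append with the generated push-back op. Passing
# `max_num_elements=None` above is translated to -1, which the kernel treats
# as "no bound".
#
#   l = empty_tensor_list(element_shape=[2], element_dtype=dtypes.float32)
#   l = gen_list_ops.tensor_list_push_back(l, [1.0, 2.0])
#   n = gen_list_ops.tensor_list_length(l)  # n == 1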


def _set_handle_data(list_handle, element_shape, element_dtype):
  """Sets type information on `list_handle` for consistency with graphs."""
  # TODO(b/169968286): It would be better if we had a consistent story for
  # creating handle data from eager operations (shared with VarHandleOp).
  if isinstance(list_handle, ops.EagerTensor):
    if tensor_util.is_tf_type(element_shape):
      element_shape = tensor_shape.TensorShape(None)
    elif not isinstance(element_shape, tensor_shape.TensorShape):
      element_shape = tensor_shape.TensorShape(element_shape)
    handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
    handle_data.is_set = True
    handle_data.shape_and_type.append(
        cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
            shape=element_shape.as_proto(),
            dtype=element_dtype.as_datatype_enum,
            specialized_type=types_pb2.ST_TENSOR_LIST))
    list_handle._handle_data = handle_data  # pylint: disable=protected-access


def tensor_list_reserve(element_shape, num_elements, element_dtype, name=None):
  result = gen_list_ops.tensor_list_reserve(
      element_shape=_build_element_shape(element_shape),
      num_elements=num_elements,
      element_dtype=element_dtype,
      name=name)
  # TODO(b/169968286): gen_ops needs to ensure the metadata is properly
  # populated for eager operations.
  _set_handle_data(result, element_shape, element_dtype)
  return result


def tensor_list_from_tensor(tensor, element_shape, name=None):
  tensor = ops.convert_to_tensor(tensor)
  result = gen_list_ops.tensor_list_from_tensor(
      tensor=tensor,
      element_shape=_build_element_shape(element_shape),
      name=name)
  _set_handle_data(result, tensor.shape, tensor.dtype)
  return result


def tensor_list_get_item(input_handle, index, element_dtype, element_shape=None,
                         name=None):
  return gen_list_ops.tensor_list_get_item(
      input_handle=input_handle,
      index=index,
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      name=name)


def tensor_list_pop_back(input_handle, element_dtype, name=None):
  return gen_list_ops.tensor_list_pop_back(
      input_handle=input_handle,
      element_shape=-1,
      element_dtype=element_dtype,
      name=name)


def tensor_list_gather(input_handle,
                       indices,
                       element_dtype,
                       element_shape=None,
                       name=None):
  return gen_list_ops.tensor_list_gather(
      input_handle=input_handle,
      indices=indices,
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      name=name)


def tensor_list_scatter(tensor,
                        indices,
                        element_shape=None,
                        input_handle=None,
                        name=None):
  """Returns a TensorList created or updated by scattering `tensor`."""
  tensor = ops.convert_to_tensor(tensor)
  if input_handle is not None:
    output_handle = gen_list_ops.tensor_list_scatter_into_existing_list(
        input_handle=input_handle, tensor=tensor, indices=indices, name=name)
    handle_data_util.copy_handle_data(input_handle, output_handle)
    return output_handle
  else:
    output_handle = gen_list_ops.tensor_list_scatter_v2(
        tensor=tensor,
        indices=indices,
        element_shape=_build_element_shape(element_shape),
        num_elements=-1,
        name=name)
    _set_handle_data(output_handle, element_shape, tensor.dtype)
    return output_handle


def tensor_list_stack(input_handle,
                      element_dtype,
                      num_elements=-1,
                      element_shape=None,
                      name=None):
  return gen_list_ops.tensor_list_stack(
      input_handle=input_handle,
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      num_elements=num_elements,
      name=name)
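
# A round-trip sketch (illustrative only, not part of the module): reserve a
# list, fill it, and stack it back into a dense tensor.
# `tensor_list_set_item` is defined further below.
#
#   l = tensor_list_reserve(element_shape=[2], num_elements=2,
#                           element_dtype=dtypes.float32)
#   l = tensor_list_set_item(l, 0, [1.0, 2.0])
#   l = tensor_list_set_item(l, 1, [3.0, 4.0])
#   t = tensor_list_stack(l, element_dtype=dtypes.float32)  # shape [2, 2]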


def tensor_list_concat(input_handle, element_dtype, element_shape=None,
                       name=None):
  # Ignore the lengths output of TensorListConcat. It is only used during
  # gradient computation.
  return gen_list_ops.tensor_list_concat_v2(
      input_handle=input_handle,
      element_dtype=element_dtype,
      element_shape=_build_element_shape(element_shape),
      leading_dims=ops.convert_to_tensor([], dtype=dtypes.int64),
      name=name)[0]


def tensor_list_split(tensor, element_shape, lengths, name=None):
  return gen_list_ops.tensor_list_split(
      tensor=tensor,
      element_shape=_build_element_shape(element_shape),
      lengths=lengths,
      name=name)


def tensor_list_set_item(input_handle,
                         index,
                         item,
                         resize_if_index_out_of_bounds=False,
                         name=None):
  """Sets `item` at `index` in the input list."""
  if resize_if_index_out_of_bounds:
    input_list_size = gen_list_ops.tensor_list_length(input_handle)
    # TODO(srbs): This could cause some slowdown. Consider fusing resize
    # functionality in the SetItem op.
    input_handle = control_flow_ops.cond(
        index >= input_list_size,
        lambda: gen_list_ops.tensor_list_resize(  # pylint: disable=g-long-lambda
            input_handle, index + 1),
        lambda: input_handle)
  output_handle = gen_list_ops.tensor_list_set_item(
      input_handle=input_handle, index=index, item=item, name=name)
  handle_data_util.copy_handle_data(input_handle, output_handle)
  return output_handle
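
# Sketch of the resize-on-set behavior above (illustrative only, not part of
# the module): with `resize_if_index_out_of_bounds=True`, writing past the
# end grows the list first, so this succeeds on an initially empty list.
#
#   l = empty_tensor_list(element_shape=[], element_dtype=dtypes.float32)
#   l = tensor_list_set_item(l, 3, 42.0, resize_if_index_out_of_bounds=True)
#   n = gen_list_ops.tensor_list_length(l)  # n == 4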


@ops.RegisterGradient("TensorListPushBack")
def _PushBackGrad(op, dresult):
  return gen_list_ops.tensor_list_pop_back(
      dresult,
      element_shape=array_ops.shape(op.inputs[1]),
      element_dtype=op.get_attr("element_dtype"))


@ops.RegisterGradient("TensorListPopBack")
def _PopBackGrad(op, dlist, delement):
  if dlist is None:
    dlist = empty_tensor_list(
        element_dtype=delement.dtype,
        element_shape=gen_list_ops.tensor_list_element_shape(
            op.outputs[0], shape_type=dtypes.int32))
  if delement is None:
    delement = array_ops.zeros_like(op.outputs[1])
  return gen_list_ops.tensor_list_push_back(dlist, delement), None


@ops.RegisterGradient("TensorListStack")
def _TensorListStackGrad(unused_op, dtensor):
  return tensor_list_from_tensor(dtensor, element_shape=dtensor.shape[1:]), None


@ops.RegisterGradient("TensorListConcat")
@ops.RegisterGradient("TensorListConcatV2")
def _TensorListConcatGrad(op, dtensor, unused_dlengths):
  """Gradient function for TensorListConcat."""
  dlist = tensor_list_split(
      dtensor,
      element_shape=gen_list_ops.tensor_list_element_shape(
          op.inputs[0], shape_type=dtypes.int32),
      lengths=op.outputs[1])
  if op.type == "TensorListConcatV2":
    return dlist, None, None
  else:
    return dlist


@ops.RegisterGradient("TensorListSplit")
def _TensorListSplitGrad(op, dlist):
  tensor, _, lengths = op.inputs
  element_shape = array_ops.slice(array_ops.shape(tensor), [1], [-1])
  element_shape = array_ops.concat([[-1], element_shape], axis=0)
  return gen_list_ops.tensor_list_concat_v2(
      dlist,
      element_shape=element_shape,
      leading_dims=lengths,
      element_dtype=op.inputs[0].dtype)[0], None, None


@ops.RegisterGradient("TensorListFromTensor")
def _TensorListFromTensorGrad(op, dlist):
  """Gradient for TensorListFromTensor."""
  t = op.inputs[0]
  if t.shape.dims and t.shape.dims[0].value is not None:
    num_elements = t.shape.dims[0].value
  else:
    num_elements = None
  if dlist is None:
    dlist = empty_tensor_list(
        element_dtype=t.dtype,
        element_shape=gen_list_ops.tensor_list_element_shape(
            op.outputs[0], shape_type=dtypes.int32))
  tensor_grad = gen_list_ops.tensor_list_stack(
      dlist,
      element_shape=array_ops.slice(array_ops.shape(t), [1], [-1]),
      element_dtype=t.dtype,
      num_elements=num_elements)
  shape_grad = None
  return tensor_grad, shape_grad


@ops.RegisterGradient("TensorListGetItem")
def _TensorListGetItemGrad(op, ditem):
  """Gradient for TensorListGetItem."""
  list_size = gen_list_ops.tensor_list_length(op.inputs[0])
  list_grad = gen_list_ops.tensor_list_set_item(
      gen_list_ops.tensor_list_reserve(
          gen_list_ops.tensor_list_element_shape(op.inputs[0],
                                                 shape_type=dtypes.int32),
          list_size, element_dtype=ditem.dtype),
      index=op.inputs[1],
      item=ditem)
  index_grad = None
  element_shape_grad = None
  return list_grad, index_grad, element_shape_grad


@ops.RegisterGradient("TensorListSetItem")
def _TensorListSetItemGrad(op, dlist):
  """Gradient function for TensorListSetItem."""
  _, index, item = op.inputs
  list_grad = gen_list_ops.tensor_list_set_item(
      dlist, index=index, item=array_ops.zeros_like(item))
  index_grad = None
  element_grad = tensor_list_get_item(
      dlist,
      index,
      element_shape=array_ops.shape(item),
      element_dtype=item.dtype)
  return list_grad, index_grad, element_grad


@ops.RegisterGradient("TensorListResize")
def _TensorListResizeGrad(op, dlist):
  input_list, _ = op.inputs
  input_list_size = gen_list_ops.tensor_list_length(input_list)
  return gen_list_ops.tensor_list_resize(dlist, input_list_size), None


@ops.RegisterGradient("TensorListGather")
def _TensorListGatherGrad(op, dtensor):
  """Gradient function for TensorListGather."""
  input_list, indices, _ = op.inputs
  element_shape = gen_list_ops.tensor_list_element_shape(
      input_list, shape_type=dtypes.int32)
  num_elements = gen_list_ops.tensor_list_length(input_list)
  dlist = tensor_list_reserve(element_shape, num_elements, dtensor.dtype)
  dlist = tensor_list_scatter(
      tensor=dtensor, indices=indices, input_handle=dlist)
  return dlist, None, None


@ops.RegisterGradient("TensorListScatter")
@ops.RegisterGradient("TensorListScatterV2")
def _TensorListScatterGrad(op, dlist):
  """Gradient function for TensorListScatter."""
  tensor = op.inputs[0]
  indices = op.inputs[1]
  dtensor = gen_list_ops.tensor_list_gather(
      dlist,
      indices,
      element_shape=array_ops.slice(array_ops.shape(tensor), [1], [-1]),
      element_dtype=tensor.dtype)
  if op.type == "TensorListScatterV2":
    return dtensor, None, None, None
  else:
    return dtensor, None, None


@ops.RegisterGradient("TensorListScatterIntoExistingList")
def _TensorListScatterIntoExistingListGrad(op, dlist):
  """Gradient function for TensorListScatterIntoExistingList."""
  _, tensor, indices = op.inputs
  dtensor = gen_list_ops.tensor_list_gather(
      dlist,
      indices,
      element_shape=array_ops.slice(array_ops.shape(tensor), [1], [-1]),
      element_dtype=tensor.dtype)
  # The scattered positions were overwritten by the forward op, so their
  # gradient w.r.t. the input list is zero; zero them out before passing the
  # list gradient on.
  zeros = array_ops.zeros_like(tensor)
  dlist = tensor_list_scatter(zeros, indices, input_handle=dlist)
  return dlist, dtensor, None
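
# Gradient-flow sketch (illustrative only, not part of the module; assumes
# eager mode and imports tensorflow.python.eager.backprop directly). The
# gradients registered above make list ops differentiable end to end:
#
#   from tensorflow.python.eager import backprop
#   t = ops.convert_to_tensor([[1.0, 2.0], [3.0, 4.0]])
#   with backprop.GradientTape() as tape:
#     tape.watch(t)
#     l = tensor_list_from_tensor(t, element_shape=[2])
#     out = tensor_list_stack(l, element_dtype=dtypes.float32)
#   grad = tape.gradient(out, t)  # ones with the same shape as t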


def _build_element_shape(shape):
  """Converts shape to a format understood by list_ops for element_shape.

  If `shape` is already a `Tensor` it is returned as-is. We do not perform a
  type check here.

  If shape is None or a TensorShape of unknown rank, -1 is returned.

  If shape is a numpy array or an empty (scalar) shape, an int32 tensor is
  returned. Note we do not return an empty list directly, since
  ops.convert_to_tensor would convert it to float32, which is not a valid
  type for element_shape.

  If shape is a sequence of dims, None's in the list are replaced with -1. We
  do not check the dtype of the other dims.

  Args:
    shape: Could be None, Tensor, TensorShape or a list of dims (each dim could
      be a None, scalar or Tensor).

  Returns:
    A None-free shape that can be converted to a tensor.
  """
  if isinstance(shape, ops.Tensor):
    return shape
  if isinstance(shape, tensor_shape.TensorShape):
    # `TensorShape.as_list` requires rank to be known.
    shape = shape.as_list() if shape else None
  # Shape is unknown.
  if shape is None:
    return -1
  # Shape is numpy array or a scalar.
  if isinstance(shape, (np.ndarray, np.generic)) or not shape:
    return ops.convert_to_tensor(shape, dtype=dtypes.int32)

  # Shape is a sequence of dimensions. Convert None dims to -1.
  def convert(val):
    if val is None:
      return -1
    if isinstance(val, ops.Tensor):
      return val
    if isinstance(val, tensor_shape.Dimension):
      return val.value if val.value is not None else -1
    return val

  return [convert(d) for d in shape]
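
# Behavior of `_build_element_shape` at a glance (illustrative comments
# derived from the branches above, not executed):
#
#   _build_element_shape(None)                             # -1 (unknown rank)
#   _build_element_shape(tensor_shape.TensorShape(None))   # -1
#   _build_element_shape([])                # empty int32 tensor (scalar case)
#   _build_element_shape([None, 3])                        # [-1, 3]
#   _build_element_shape(tensor_shape.TensorShape([None, 3]))  # [-1, 3]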