# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops to manipulate lists of tensors."""

# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.core.framework import full_type_pb2
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import handle_data_util
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_list_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util.lazy_loader import LazyLoader

# Lazy import to break the circular dependency:
# list_ops -> control_flow_ops -> tensor_array_ops -> list_ops
control_flow_ops = LazyLoader(
    "control_flow_ops", globals(),
    "tensorflow.python.ops.control_flow_ops")


ops.NotDifferentiable("TensorListConcatLists")
ops.NotDifferentiable("TensorListElementShape")
ops.NotDifferentiable("TensorListLength")
ops.NotDifferentiable("TensorListPushBackBatch")


def empty_tensor_list(element_shape,
                      element_dtype,
                      max_num_elements=None,
                      name=None):
  """Returns an empty TensorList with the given element shape and dtype."""
  if max_num_elements is None:
    max_num_elements = -1

  return gen_list_ops.empty_tensor_list(
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      max_num_elements=max_num_elements,
      name=name)
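

# A minimal usage sketch, kept in comments so nothing executes at import time.
# Assuming eager execution, a list can be grown with the generated
# TensorListPushBack op and stacked back into a dense tensor:
#
#   l = empty_tensor_list(element_shape=[], element_dtype=dtypes.float32)
#   l = gen_list_ops.tensor_list_push_back(l, 1.0)
#   l = gen_list_ops.tensor_list_push_back(l, 2.0)
#   tensor_list_stack(l, element_dtype=dtypes.float32)  # => [1.0, 2.0]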


def _set_handle_data(list_handle, element_shape, element_dtype):
  """Sets type information on `list_handle` for consistency with graphs."""
  # TODO(b/169968286): It would be better if we had a consistent story for
  # creating handle data from eager operations (shared with VarHandleOp).
  if isinstance(list_handle, ops.EagerTensor):
    if tensor_util.is_tf_type(element_shape):
      element_shape = tensor_shape.TensorShape(None)
    elif not isinstance(element_shape, tensor_shape.TensorShape):
      element_shape = tensor_shape.TensorShape(element_shape)
    handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
    handle_data.is_set = True
    # TODO(b/191472076): This duplicates type inference. Clean up.
    handle_data.shape_and_type.append(
        cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
            shape=element_shape.as_proto(),
            dtype=element_dtype.as_datatype_enum,
            type=full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_ARRAY)))
    list_handle._handle_data = handle_data  # pylint: disable=protected-access


def tensor_list_reserve(element_shape, num_elements, element_dtype, name=None):
  """Returns a TensorList of size `num_elements` with uninitialized elements."""
  result = gen_list_ops.tensor_list_reserve(
      element_shape=_build_element_shape(element_shape),
      num_elements=num_elements,
      element_dtype=element_dtype,
      name=name)
  # TODO(b/169968286): gen_ops needs to ensure the metadata is properly
  # populated for eager operations.
  _set_handle_data(result, element_shape, element_dtype)
  return result


def tensor_list_from_tensor(tensor, element_shape, name=None):
  """Returns a TensorList whose elements are the unstacked slices of `tensor`."""
  tensor = ops.convert_to_tensor(tensor)
  result = gen_list_ops.tensor_list_from_tensor(
      tensor=tensor,
      element_shape=_build_element_shape(element_shape),
      name=name)
  _set_handle_data(result, tensor.shape, tensor.dtype)
  return result


def tensor_list_get_item(input_handle, index, element_dtype, element_shape=None,
                         name=None):
  """Returns the element at `index` in the input list."""
  return gen_list_ops.tensor_list_get_item(
      input_handle=input_handle,
      index=index,
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      name=name)


def tensor_list_pop_back(input_handle, element_dtype, name=None):
  """Returns the list with its last element removed, and that element."""
  return gen_list_ops.tensor_list_pop_back(
      input_handle=input_handle,
      element_shape=-1,
      element_dtype=element_dtype,
      name=name)


def tensor_list_gather(input_handle,
                       indices,
                       element_dtype,
                       element_shape=None,
                       name=None):
  """Returns the elements at `indices` stacked into a single tensor."""
  return gen_list_ops.tensor_list_gather(
      input_handle=input_handle,
      indices=indices,
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      name=name)


def tensor_list_scatter(tensor,
                        indices,
                        element_shape=None,
                        input_handle=None,
                        name=None):
  """Returns a TensorList created or updated by scattering `tensor`."""
  tensor = ops.convert_to_tensor(tensor)
  if input_handle is not None:
    output_handle = gen_list_ops.tensor_list_scatter_into_existing_list(
        input_handle=input_handle, tensor=tensor, indices=indices, name=name)
    handle_data_util.copy_handle_data(input_handle, output_handle)
    return output_handle
  else:
    output_handle = gen_list_ops.tensor_list_scatter_v2(
        tensor=tensor,
        indices=indices,
        element_shape=_build_element_shape(element_shape),
        num_elements=-1,
        name=name)
    _set_handle_data(output_handle, element_shape, tensor.dtype)
    return output_handle


def tensor_list_stack(input_handle,
                      element_dtype,
                      num_elements=-1,
                      element_shape=None,
                      name=None):
  """Returns all elements of the list stacked into a single tensor."""
  return gen_list_ops.tensor_list_stack(
      input_handle=input_handle,
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      num_elements=num_elements,
      name=name)
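

# An illustrative round trip through the ops above (comments only;
# `constant_op` is not imported in this module, so treat this as a sketch
# under eager execution):
#
#   t = constant_op.constant([[1., 2.], [3., 4.]])
#   l = tensor_list_from_tensor(t, element_shape=[2])
#   tensor_list_get_item(l, 0, element_dtype=dtypes.float32)  # => [1., 2.]
#   tensor_list_stack(l, element_dtype=dtypes.float32)        # => t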


def tensor_list_concat(input_handle, element_dtype, element_shape=None,
                       name=None):
  """Returns all elements of the list concatenated along the first dimension."""
  # Ignore the lengths output of TensorListConcat. It is only used during
  # gradient computation.
  return gen_list_ops.tensor_list_concat_v2(
      input_handle=input_handle,
      element_dtype=element_dtype,
      element_shape=_build_element_shape(element_shape),
      leading_dims=ops.convert_to_tensor([], dtype=dtypes.int64),
      name=name)[0]


def tensor_list_split(tensor, element_shape, lengths, name=None):
  """Returns a TensorList whose elements split `tensor` along dimension 0."""
  return gen_list_ops.tensor_list_split(
      tensor=tensor,
      element_shape=_build_element_shape(element_shape),
      lengths=lengths,
      name=name)


def tensor_list_set_item(input_handle,
                         index,
                         item,
                         resize_if_index_out_of_bounds=False,
                         name=None):
  """Sets `item` at `index` in input list."""
  if resize_if_index_out_of_bounds:
    input_list_size = gen_list_ops.tensor_list_length(input_handle)
    # TODO(srbs): This could cause some slowdown. Consider fusing resize
    # functionality in the SetItem op.
    input_handle = control_flow_ops.cond(
        index >= input_list_size,
        lambda: gen_list_ops.tensor_list_resize(  # pylint: disable=g-long-lambda
            input_handle, index + 1),
        lambda: input_handle)
  output_handle = gen_list_ops.tensor_list_set_item(
      input_handle=input_handle, index=index, item=item, name=name)
  handle_data_util.copy_handle_data(input_handle, output_handle)
  return output_handle


@ops.RegisterGradient("TensorListPushBack")
def _PushBackGrad(op, dresult):
  """Gradient function for TensorListPushBack."""
  return gen_list_ops.tensor_list_pop_back(
      dresult,
      element_shape=array_ops.shape(op.inputs[1]),
      element_dtype=op.get_attr("element_dtype"))


@ops.RegisterGradient("TensorListPopBack")
def _PopBackGrad(op, dlist, delement):
  """Gradient function for TensorListPopBack."""
  if dlist is None:
    dlist = empty_tensor_list(
        element_dtype=delement.dtype,
        element_shape=gen_list_ops.tensor_list_element_shape(
            op.outputs[0], shape_type=dtypes.int32))
  if delement is None:
    delement = array_ops.zeros_like(op.outputs[1])
  return gen_list_ops.tensor_list_push_back(dlist, delement), None


@ops.RegisterGradient("TensorListStack")
def _TensorListStackGrad(unused_op, dtensor):
  """Gradient function for TensorListStack."""
  return tensor_list_from_tensor(dtensor, element_shape=dtensor.shape[1:]), None


@ops.RegisterGradient("TensorListConcat")
@ops.RegisterGradient("TensorListConcatV2")
def _TensorListConcatGrad(op, dtensor, unused_dlengths):
  """Gradient function for TensorListConcat."""
  dlist = tensor_list_split(
      dtensor,
      element_shape=gen_list_ops.tensor_list_element_shape(
          op.inputs[0], shape_type=dtypes.int32),
      lengths=op.outputs[1])
  if op.type == "TensorListConcatV2":
    return dlist, None, None
  else:
    return dlist


@ops.RegisterGradient("TensorListSplit")
def _TensorListSplitGrad(op, dlist):
  """Gradient function for TensorListSplit."""
  tensor, _, lengths = op.inputs
  element_shape = array_ops.slice(array_ops.shape(tensor), [1], [-1])
  element_shape = array_ops.concat([[-1], element_shape], axis=0)
  return gen_list_ops.tensor_list_concat_v2(
      dlist,
      element_shape=element_shape,
      leading_dims=lengths,
      element_dtype=op.inputs[0].dtype)[0], None, None
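

# The two gradients above mirror the forward pairing of concat and split:
# TensorListSplit undoes TensorListConcat for matching `lengths`. A sketch,
# assuming eager execution (`constant_op` is not imported here):
#
#   t = constant_op.constant([1., 2., 3.])
#   l = tensor_list_split(t, element_shape=[-1], lengths=[1, 2])
#   # element 0 is [1.], element 1 is [2., 3.]
#   tensor_list_concat(l, element_dtype=dtypes.float32)  # => [1., 2., 3.]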


@ops.RegisterGradient("TensorListFromTensor")
def _TensorListFromTensorGrad(op, dlist):
  """Gradient for TensorListFromTensor."""
  t = op.inputs[0]
  if t.shape.dims and t.shape.dims[0].value is not None:
    num_elements = t.shape.dims[0].value
  else:
    num_elements = None
  if dlist is None:
    dlist = empty_tensor_list(
        element_dtype=t.dtype,
        element_shape=gen_list_ops.tensor_list_element_shape(
            op.outputs[0], shape_type=dtypes.int32))
  tensor_grad = gen_list_ops.tensor_list_stack(
      dlist,
      element_shape=array_ops.slice(array_ops.shape(t), [1], [-1]),
      element_dtype=t.dtype,
      num_elements=num_elements)
  shape_grad = None
  return tensor_grad, shape_grad


@ops.RegisterGradient("TensorListGetItem")
def _TensorListGetItemGrad(op, ditem):
  """Gradient for TensorListGetItem."""
  list_size = gen_list_ops.tensor_list_length(op.inputs[0])
  list_grad = gen_list_ops.tensor_list_set_item(
      gen_list_ops.tensor_list_reserve(
          gen_list_ops.tensor_list_element_shape(op.inputs[0],
                                                 shape_type=dtypes.int32),
          list_size, element_dtype=ditem.dtype),
      index=op.inputs[1],
      item=ditem)
  index_grad = None
  element_shape_grad = None
  return list_grad, index_grad, element_shape_grad


@ops.RegisterGradient("TensorListSetItem")
def _TensorListSetItemGrad(op, dlist):
  """Gradient function for TensorListSetItem."""
  _, index, item = op.inputs
  list_grad = gen_list_ops.tensor_list_set_item(
      dlist, index=index, item=array_ops.zeros_like(item))
  index_grad = None
  element_grad = tensor_list_get_item(
      dlist,
      index,
      element_shape=array_ops.shape(item),
      element_dtype=item.dtype)
  return list_grad, index_grad, element_grad


@ops.RegisterGradient("TensorListResize")
def _TensorListResizeGrad(op, dlist):
  """Gradient function for TensorListResize."""
  input_list, _ = op.inputs
  input_list_size = gen_list_ops.tensor_list_length(input_list)
  return gen_list_ops.tensor_list_resize(dlist, input_list_size), None


@ops.RegisterGradient("TensorListGather")
def _TensorListGatherGrad(op, dtensor):
  """Gradient function for TensorListGather."""
  input_list, indices, _ = op.inputs
  element_shape = gen_list_ops.tensor_list_element_shape(
      input_list, shape_type=dtypes.int32)
  num_elements = gen_list_ops.tensor_list_length(input_list)
  dlist = tensor_list_reserve(element_shape, num_elements, dtensor.dtype)
  dlist = tensor_list_scatter(
      tensor=dtensor, indices=indices, input_handle=dlist)
  return dlist, None, None


@ops.RegisterGradient("TensorListScatter")
@ops.RegisterGradient("TensorListScatterV2")
def _TensorListScatterGrad(op, dlist):
  """Gradient function for TensorListScatter."""
  tensor = op.inputs[0]
  indices = op.inputs[1]
  dtensor = gen_list_ops.tensor_list_gather(
      dlist,
      indices,
      element_shape=array_ops.slice(array_ops.shape(tensor), [1], [-1]),
      element_dtype=tensor.dtype)
  if op.type == "TensorListScatterV2":
    return dtensor, None, None, None
  else:
    return dtensor, None, None


@ops.RegisterGradient("TensorListScatterIntoExistingList")
def _TensorListScatterIntoExistingListGrad(op, dlist):
  """Gradient function for TensorListScatterIntoExistingList."""
  _, tensor, indices = op.inputs
  dtensor = gen_list_ops.tensor_list_gather(
      dlist,
      indices,
      element_shape=array_ops.slice(array_ops.shape(tensor), [1], [-1]),
      element_dtype=tensor.dtype)
  # Zero out the scattered positions in the incoming list gradient, since
  # those contributions flow to `tensor` instead of the input list.
  zeros = array_ops.zeros_like(tensor)
  dlist = tensor_list_scatter(zeros, indices, input_handle=dlist)
  return dlist, dtensor, None
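

# How these registrations compose: differentiating through
# stack(from_tensor(t)) invokes _TensorListStackGrad and then
# _TensorListFromTensorGrad, yielding ones for every element of `t`. A sketch
# assuming eager execution (`backprop` is not imported in this module):
#
#   with backprop.GradientTape() as tape:
#     tape.watch(t)
#     l = tensor_list_from_tensor(t, element_shape=t.shape[1:])
#     out = tensor_list_stack(l, element_dtype=t.dtype)
#   tape.gradient(out, t)  # => ones_like(t)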


def _build_element_shape(shape):
  """Converts `shape` to a format understood by list_ops for element_shape.

  If `shape` is already a `Tensor` it is returned as-is. We do not perform a
  type check here.

  If `shape` is None or a `TensorShape` of unknown rank, -1 is returned.

  If `shape` denotes a scalar element (an empty sequence), an empty int32
  tensor is returned. Note we do not directly return the empty list, since
  `ops.convert_to_tensor` would convert it to a float32 tensor, which is not a
  valid type for element_shape.

  If `shape` is a sequence of dims, None's in the list are replaced with -1. We
  do not check the dtype of the other dims.

  Args:
    shape: Could be None, Tensor, TensorShape or a list of dims (each dim could
      be a None, scalar or Tensor).

  Returns:
    A None-free shape that can be converted to a tensor.
  """
  if isinstance(shape, ops.Tensor):
    return shape
  if isinstance(shape, tensor_shape.TensorShape):
    # `TensorShape.as_list` requires rank to be known.
    shape = shape.as_list() if shape else None
  # Shape is unknown.
  if shape is None:
    return -1
  # Shape is a numpy array or a scalar.
  if isinstance(shape, (np.ndarray, np.generic)) or not shape:
    return ops.convert_to_tensor(shape, dtype=dtypes.int32)

  # Shape is a sequence of dimensions. Convert None dims to -1.
  def convert(val):
    if val is None:
      return -1
    if isinstance(val, ops.Tensor):
      return val
    if isinstance(val, tensor_shape.Dimension):
      return val.value if val.value is not None else -1
    return val

  return [convert(d) for d in shape]
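

# Expected conversions for _build_element_shape, derived from the branches
# above (a sketch, not an exhaustive test):
#
#   _build_element_shape(None)                              # => -1
#   _build_element_shape(tensor_shape.TensorShape(None))    # => -1
#   _build_element_shape(tensor_shape.TensorShape([2, 3]))  # => [2, 3]
#   _build_element_shape([None, 2])                         # => [-1, 2]
#   _build_element_shape([])                                # => empty int32 tensor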