# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operators specific to slicing operations."""

import collections

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import tensor_array_ops


# TODO(mdan): Support extended slices.


class GetItemOpts(collections.namedtuple('GetItemOpts', ('element_dtype',))):
  """Options passed to `get_item`.

  Attributes:
    element_dtype: Element dtype used when reading from a variant-dtype
      tensor (a TensorList). Reading from such a tensor with
      `element_dtype=None` raises ValueError; the field is not consulted
      for any other kind of target.
  """


def get_item(target, i, opts):
  """The slice read operator (i.e. __getitem__).

  Dispatches on the kind of `target`:
    * `TensorArray`: stages `target.read(i)`.
    * TF tensor with variant dtype (treated as a TensorList): stages
      `list_ops.tensor_list_get_item` with `opts.element_dtype`.
    * scalar (rank-0) string tensor: stages a length-1 `substr` at `i`.
    * any other TF tensor: delegates to the tensor's own `target[i]`.
    * anything else: plain Python `target[i]`.

  Args:
    target: An entity that supports getitem semantics.
    i: Index to read from.
    opts: A GetItemOpts object.

  Returns:
    The read element.

  Raises:
    ValueError: if `target` is a variant-dtype (TensorList) tensor and
      `opts.element_dtype` is None. Targets that are not TF types are not
      rejected; they are indexed with Python `target[i]`, and any error that
      raises propagates unchanged.
  """
  assert isinstance(opts, GetItemOpts)

  if isinstance(target, tensor_array_ops.TensorArray):
    return _tf_tensorarray_get_item(target, i)
  elif tensor_util.is_tf_type(target):
    if target.dtype == dtypes.variant:
      return _tf_tensor_list_get_item(target, i, opts)
    elif target.dtype == dtypes.string and target.shape.ndims == 0:
      # Only scalar strings get character-style indexing; string tensors of
      # higher rank are indexed along their first dimension like any tensor.
      return _tf_tensor_string_get_item(target, i)
    else:
      return _tf_tensor_get_item(target, i)
  else:
    return _py_get_item(target, i)


def _tf_tensorarray_get_item(target, i):
  """Overload of get_item that stages a TensorArray read."""
  return target.read(i)


def _tf_tensor_list_get_item(target, i, opts):
  """Overload of get_item that stages a Tensor list read."""
  # The variant tensor does not expose its element dtype here, so the caller
  # must supply it through opts.
  if opts.element_dtype is None:
    raise ValueError('cannot retrieve from a list without knowing its '
                     'element type; use set_element_type to annotate it')
  return list_ops.tensor_list_get_item(
      target, i, element_dtype=opts.element_dtype)


def _tf_tensor_get_item(target, i):
  """Overload of get_item that stages a Tensor (not Tensor list) read."""
  return target[i]


def _tf_tensor_string_get_item(target, i):
  """Overload of get_item that stages a Tensor string read."""
  # NOTE(review): substr is called without `unit`, so its default unit
  # applies (bytes, per the tf.strings.substr docs); for non-ASCII strings
  # this reads one byte rather than one character — confirm this is intended.
  return gen_string_ops.substr(target, i, 1)


def _py_get_item(target, i):
  """Overload of get_item that executes a Python read (`target[i]`)."""
  return target[i]


def set_item(target, i, x):
  """The slice write operator (i.e. __setitem__).

  Note: it is unspecified whether target will be mutated or not. In general,
  if target is mutable (like Python lists), it will be mutated.

  Dispatches on the kind of `target`:
    * `TensorArray`: stages `target.write(i, x)`.
    * TF tensor with variant dtype (treated as a TensorList): stages
      `list_ops.tensor_list_set_item`.
    * any other TF tensor: stages a single-index `tensor_scatter_update`.
    * anything else: plain Python `target[i] = x`.

  Args:
    target: An entity that supports setitem semantics.
    i: Index to modify.
    x: The new element value.

  Returns:
    Same as target, after the update was performed.

  Raises:
    This function raises nothing itself. Targets that are not TF types are
    not rejected; they are updated with Python `target[i] = x`, and any error
    raised by that assignment or by the staged TF ops propagates unchanged.
  """
  if isinstance(target, tensor_array_ops.TensorArray):
    return _tf_tensorarray_set_item(target, i, x)
  elif tensor_util.is_tf_type(target):
    if target.dtype == dtypes.variant:
      return _tf_tensor_list_set_item(target, i, x)
    else:
      return _tf_tensor_set_item(target, i, x)
  else:
    return _py_set_item(target, i, x)


def _tf_tensorarray_set_item(target, i, x):
  """Overload of set_item that stages a TensorArray write."""
  return target.write(i, x)


def _tf_tensor_list_set_item(target, i, x):
  """Overload of set_item that stages a Tensor list update."""
  return list_ops.tensor_list_set_item(target, i, x)


def _tf_tensor_set_item(target, i, x):
  """Overload of set_item that stages a Tensor scatter update."""
  # Wrap i and x so that exactly one (index, update) pair is scattered:
  # indices is [[i]] and updates is [x].
  return gen_array_ops.tensor_scatter_update(target, ((i,),), (x,))


def _py_set_item(target, i, x):
  """Overload of set_item that executes a Python list modification."""
  target[i] = x
  return target