# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for slices module."""

from tensorflow.python.autograph.operators import slices
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import list_ops
from tensorflow.python.platform import test


class SlicesTest(test.TestCase):
  """Tests for `slices.set_item` and `slices.get_item`."""

  def test_set_item_tensor_list(self):
    """set_item on a TensorList replaces the element at the given index."""
    initial_list = constant_op.constant([[1, 2], [3, 4]])
    elem_shape = constant_op.constant([2])
    l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape)
    # set_item returns the updated list; rebind so the stack below sees it.
    l = slices.set_item(l, 0, [5, 6])

    with self.cached_session():
      t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
      self.assertAllEqual(self.evaluate(t), [[5, 6], [3, 4]])

  def test_get_item_tensor_list(self):
    """get_item on a TensorList returns the element at the given index."""
    initial_list = constant_op.constant([[1, 2], [3, 4]])
    elem_shape = constant_op.constant([2])
    l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape)
    # A TensorList does not carry its element dtype in a form get_item can
    # read here, so it is passed explicitly through GetItemOpts.
    t = slices.get_item(
        l, 1, slices.GetItemOpts(element_dtype=initial_list.dtype))

    with self.cached_session():
      self.assertAllEqual(self.evaluate(t), [3, 4])

  def test_get_item_tensor_string(self):
    """get_item indexes characters of a scalar string and items of a 1-D one."""
    # Scalar string tensor: indexing yields a single character.
    initial_str = constant_op.constant('abcd')
    t = slices.get_item(initial_str, 1,
                        slices.GetItemOpts(element_dtype=initial_str.dtype))

    with self.cached_session():
      self.assertEqual(self.evaluate(t), b'b')

    # 1-D string tensor: indexing yields a whole element.
    initial_list_str = constant_op.constant(['abcd', 'bcde'])
    # Use the dtype of the tensor actually being indexed (the original code
    # reused initial_str.dtype, which only worked because both are strings).
    t = slices.get_item(initial_list_str, 1,
                        slices.GetItemOpts(element_dtype=initial_list_str.dtype))

    with self.cached_session():
      self.assertEqual(self.evaluate(t), b'bcde')


if __name__ == '__main__':
  test.main()