• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.compiler.tests.unstack."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from absl.testing import parameterized
22import numpy as np
23
24from tensorflow.compiler.tests import xla_test
25from tensorflow.python.ops import array_ops
26from tensorflow.python.platform import test
27
28
class UnstackOpTest(xla_test.XLATestCase, parameterized.TestCase):
  """Tests `array_ops.unstack` executed inside the XLA test scope."""

  def _test(self, size, inner_dim=512):
    """Unstacks a `[size, inner_dim]` float32 input and checks every slice.

    Each row of the input holds distinct values, so the check catches slices
    that are missing, mis-shaped, or returned in the wrong order. An all-zeros
    input would miss these: every row looks the same, and `np.all` over an
    empty comparison is vacuously True.

    Args:
      size: Number of rows in the input. Unstacking along axis 0 is expected
        to yield this many tensors.
      inner_dim: Length of each row. Defaults to 512, the width this test
        has always used.
    """
    # Row i holds the values [i * inner_dim, (i + 1) * inner_dim). These stay
    # distinct and exactly representable in float32 while
    # size * inner_dim < 2**24.
    x_np = np.arange(size * inner_dim, dtype=np.float32).reshape(
        size, inner_dim)
    with self.session() as sess:
      x_tf = array_ops.placeholder(np.float32, shape=[size, inner_dim])
      with self.test_scope():
        ret = array_ops.unstack(x_tf)
      # `ret` is already a list of tensors, so fetch it directly rather than
      # wrapping it in another list.
      ret_vals = sess.run(ret, feed_dict={x_tf: x_np})
      self.assertLen(ret_vals, size)
      for i, ret_val in enumerate(ret_vals):
        # assertAllEqual also compares shapes, so slice i must be exactly
        # row i of the input, shape [inner_dim].
        self.assertAllEqual(x_np[i], ret_val)

  def testLarge2000(self):
    """Unstacks a 2000-row input."""
    self._test(2000)
44
def _main():
  """Script entry point: hand control to the `test` module's runner."""
  test.main()


if __name__ == "__main__":
  _main()
47