# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for tensor_util."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import sys

from absl.testing import parameterized
import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test
42
43
@test_util.run_all_in_graph_and_eager_modes
class TensorUtilTest(test.TestCase, parameterized.TestCase):
  """Round-trip tests for make_tensor_proto / MakeNdarray.

  Each test builds a TensorProto from Python or NumPy values, checks the
  serialized proto contents (branching on host byte order where values are
  packed into tensor_content), and converts back to an ndarray.
  """

  def testFloat(self):
    value = 10.0
    t = tensor_util.make_tensor_proto(value)
    self.assertProtoEquals("""
      dtype: DT_FLOAT
      tensor_shape {}
      float_val: %.1f
      """ % value, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array(value, dtype=np.float32), a)

  def testFloatN(self):
    t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0])
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "A \000\000A\240\000\000A\360\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\000 A\000\000\240A\000\000\360A"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testFloatTyped(self):
    t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], dtype=dtypes.float32)
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "A \000\000A\240\000\000A\360\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\000 A\000\000\240A\000\000\360A"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testFloatTypeCoerce(self):
    # Ints coerce to the requested float dtype.
    t = tensor_util.make_tensor_proto([10, 20, 30], dtype=dtypes.float32)
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "A \000\000A\240\000\000A\360\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\000 A\000\000\240A\000\000\360A"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testFloatTypeCoerceNdarray(self):
    arr = np.asarray([10, 20, 30], dtype="int")
    t = tensor_util.make_tensor_proto(arr, dtype=dtypes.float32)
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "A \000\000A\240\000\000A\360\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\000 A\000\000\240A\000\000\360A"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testFloatSizes(self):
    t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], shape=[1, 3])
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 1 } dim { size: 3 } }
        tensor_content: "A \000\000A\240\000\000A\360\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 1 } dim { size: 3 } }
        tensor_content: "\000\000 A\000\000\240A\000\000\360A"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([[10.0, 20.0, 30.0]], dtype=np.float32), a)

  def testFloatSizes2(self):
    t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], shape=[3, 1])
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } dim { size: 1 } }
        tensor_content: "A \000\000A\240\000\000A\360\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } dim { size: 1 } }
        tensor_content: "\000\000 A\000\000\240A\000\000\360A"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([[10.0], [20.0], [30.0]], dtype=np.float32), a)

  def testFloatSizesLessValues(self):
    t = tensor_util.make_tensor_proto(10.0, shape=[1, 3])
    self.assertProtoEquals("""
      dtype: DT_FLOAT
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      float_val: 10.0
      """, t)
    # No conversion to Ndarray for this one: not enough values.

  def testFloatNpArrayFloat64(self):
    t = tensor_util.make_tensor_proto(
        np.array([[10.0, 20.0, 30.0]], dtype=np.float64))
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_DOUBLE
        tensor_shape { dim { size: 1 } dim { size: 3 } }
        tensor_content: "@$\000\000\000\000\000\000@4\000\000\000\000\000\000@>\000\000\000\000\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_DOUBLE
        tensor_shape { dim { size: 1 } dim { size: 3 } }
        tensor_content: "\000\000\000\000\000\000$@\000\000\000\000\000\0004@\000\000\000\000\000\000>@"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float64, a.dtype)
    self.assertAllClose(
        np.array([[10.0, 20.0, 30.0]], dtype=np.float64),
        tensor_util.MakeNdarray(t))

  def testFloatTypesWithImplicitRepeat(self):
    # A single value with a larger shape is broadcast to fill the tensor.
    for dtype, nptype in [(dtypes.float32, np.float32),
                          (dtypes.float64, np.float64)]:
      t = tensor_util.make_tensor_proto([10.0], shape=[3, 4], dtype=dtype)
      a = tensor_util.MakeNdarray(t)
      self.assertAllClose(
          np.array(
              [[10.0, 10.0, 10.0, 10.0],
               [10.0, 10.0, 10.0, 10.0],
               [10.0, 10.0, 10.0, 10.0]],
              dtype=nptype),
          a)

  def testFloatMutateArray(self):
    # Mutating the ndarray returned by MakeNdarray must not affect the proto.
    t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], dtype=dtypes.float32)
    a = tensor_util.MakeNdarray(t)
    a[0] = 5.0
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([5.0, 20.0, 30.0], dtype=np.float32), a)
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "A \000\000A\240\000\000A\360\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\000 A\000\000\240A\000\000\360A"
        """, t)

  def testHalf(self):
    t = tensor_util.make_tensor_proto(np.array([10.0, 20.0], dtype=np.float16))
    self.assertProtoEquals(
        """
      dtype: DT_HALF
      tensor_shape { dim { size: 2 } }
      tensor_content: "\000I\000M"
      """, t)

    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float16, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0], dtype=np.float16), a)

  def testBfloat16(self):
    test_type = dtypes.bfloat16.as_numpy_dtype
    t = tensor_util.make_tensor_proto(np.array([10.0, 20.0], dtype=test_type))
    # 10.0: 16672 = 010000010(130) 0100000: (1+0/2+1/4) * 2^(130-127)
    # 20.0: 16800 = 010000011(131) 0100000: (1+0/2+1/4) * 2^(131-127)
    self.assertProtoEquals("""
      dtype: DT_BFLOAT16
      tensor_shape {
        dim {
          size: 2
        }
      }
      half_val: 16672
      half_val: 16800
      """, t)

    a = tensor_util.MakeNdarray(t)
    self.assertEqual(test_type, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0], dtype=test_type), a)

  def testInt(self):
    t = tensor_util.make_tensor_proto(10)
    self.assertProtoEquals("""
      dtype: DT_INT32
      tensor_shape {}
      int_val: 10
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.int32, a.dtype)
    self.assertAllClose(np.array(10, dtype=np.int32), a)

  def testLargeInt(self):
    value = np.iinfo(np.int64).max
    t = tensor_util.make_tensor_proto(value)
    self.assertProtoEquals("""
      dtype: DT_INT64
      tensor_shape {}
      int64_val: %d
      """ % value, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.int64, a.dtype)
    self.assertAllClose(np.array(value, dtype=np.int64), a)

  def testLargeNegativeInt(self):
    # We don't use the min np.int64 value here
    # because it breaks np.abs().
    #
    # np.iinfo(np.int64).min = -9223372036854775808
    # np.iinfo(np.int64).max = 9223372036854775807
    # np.abs(-9223372036854775808) = -9223372036854775808
    value = np.iinfo(np.int64).min + 1
    t = tensor_util.make_tensor_proto(value)
    self.assertProtoEquals("""
      dtype: DT_INT64
      tensor_shape {}
      int64_val: %d
      """ % value, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.int64, a.dtype)
    self.assertAllClose(np.array(value, dtype=np.int64), a)

  def testIntNDefaultType(self):
    t = tensor_util.make_tensor_proto([10, 20, 30, 40], shape=[2, 2])
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_INT32
        tensor_shape { dim { size: 2 } dim { size: 2 } }
        tensor_content: "\000\000\000\n\000\000\000\024\000\000\000\036\000\000\000("
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_INT32
        tensor_shape { dim { size: 2 } dim { size: 2 } }
        tensor_content: "\n\000\000\000\024\000\000\000\036\000\000\000(\000\000\000"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.int32, a.dtype)
    self.assertAllClose(np.array([[10, 20], [30, 40]], dtype=np.int32), a)

  @parameterized.named_parameters(
      ("_int8", dtypes.int8, np.int8), ("_int16", dtypes.int16, np.int16),
      ("_int32", dtypes.int32, np.int32), ("_int64", dtypes.int64, np.int64),
      ("_uint8", dtypes.uint8, np.uint8), ("_uint16", dtypes.uint16, np.uint16),
      ("_uint32", dtypes.uint32, np.uint32),
      ("_uint64", dtypes.uint64, np.uint64))
  def testIntTypes(self, dtype, nptype):
    # Test with array.
    t = tensor_util.make_tensor_proto([10, 20, 30], dtype=dtype)
    self.assertEqual(dtype, t.dtype)
    self.assertProtoEquals("dim { size: 3 }", t.tensor_shape)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(nptype, a.dtype)
    self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)
    # Test with ndarray.
    t = tensor_util.make_tensor_proto(np.array([10, 20, 30], dtype=nptype))
    self.assertEqual(dtype, t.dtype)
    self.assertProtoEquals("dim { size: 3 }", t.tensor_shape)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(nptype, a.dtype)
    self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)

  @parameterized.named_parameters(
      ("_int8", dtypes.int8, np.int8), ("_int16", dtypes.int16, np.int16),
      ("_int32", dtypes.int32, np.int32), ("_int64", dtypes.int64, np.int64),
      ("_uint8", dtypes.uint8, np.uint8), ("_uint16", dtypes.uint16, np.uint16),
      ("_uint32", dtypes.uint32, np.uint32),
      ("_uint64", dtypes.uint64, np.uint64))
  def testIntTypesWithImplicitRepeat(self, dtype, nptype):
    # The last supplied value repeats to fill the remaining elements.
    self.assertAllEqual(
        np.array([[10, 11, 12, 12], [12, 12, 12, 12], [12, 12, 12, 12]],
                 dtype=nptype),
        tensor_util.MakeNdarray(
            tensor_util.make_tensor_proto([10, 11, 12],
                                          shape=[3, 4],
                                          dtype=dtype)))

  def testIntMixedWithDimension(self):
    # Github issue: 11974
    dtype = dtypes.int32
    nptype = np.int32
    t = tensor_util.make_tensor_proto(
        [10, tensor_shape.Dimension(20), 30], dtype=dtype)
    self.assertEqual(dtype, t.dtype)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(nptype, a.dtype)
    self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)

  @parameterized.named_parameters(
      ("_int64", dtypes.int64, np.int64, "DT_INT64", "int64_val"),
      ("_uint64", dtypes.uint64, np.uint64, "DT_UINT64", "uint64_val"))
  def testLong(self, dtype, nptype, proto_dtype, proto_value_name):
    t = tensor_util.make_tensor_proto(10, dtype=dtype)
    self.assertProtoEquals(
        """
      dtype: %s
      tensor_shape {}
      %s: 10
    """ % (proto_dtype, proto_value_name), t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(nptype, a.dtype)
    self.assertAllClose(np.array(10, dtype=nptype), a)

  @parameterized.named_parameters(
      ("_int64", dtypes.int64, np.int64, "DT_INT64"),
      ("_uint64", dtypes.uint64, np.uint64, "DT_UINT64"))
  def testLongN(self, dtype, nptype, proto_dtype):
    t = tensor_util.make_tensor_proto([10, 20, 30], shape=[1, 3], dtype=dtype)
    if sys.byteorder == "big":
      # pylint: disable=line-too-long
      self.assertProtoEquals(
          r"""
        dtype: %s
        tensor_shape { dim { size: 1 } dim { size: 3 } }
        tensor_content: "\000\000\000\000\000\000\000\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036"
        """ % proto_dtype, t)
      # pylint: enable=line-too-long
    else:
      # pylint: disable=line-too-long
      self.assertProtoEquals(
          r"""
        dtype: %s
        tensor_shape { dim { size: 1 } dim { size: 3 } }
        tensor_content: "\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
        """ % proto_dtype, t)
      # pylint: enable=line-too-long
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(nptype, a.dtype)
    self.assertAllClose(np.array([[10, 20, 30]], dtype=nptype), a)

  @parameterized.named_parameters(("_int64", np.int64, "DT_INT64"),
                                  ("_uint64", np.uint64, "DT_UINT64"))
  def testLongNpArray(self, nptype, proto_dtype):
    t = tensor_util.make_tensor_proto(np.array([10, 20, 30], dtype=nptype))
    if sys.byteorder == "big":
      # pylint: disable=line-too-long
      self.assertProtoEquals(
          r"""
        dtype: %s
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\000\000\000\000\000\000\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036"
        """ % proto_dtype, t)
      # pylint: enable=line-too-long
    else:
      # pylint: disable=line-too-long
      self.assertProtoEquals(
          r"""
        dtype: %s
        tensor_shape { dim { size: 3 } }
        tensor_content: "\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
        """ % proto_dtype, t)
      # pylint: enable=line-too-long
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(nptype, a.dtype)
    self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)

  def testQuantizedTypes(self):
    # Test with array.
    data = [(21,), (22,), (23,)]

    t = tensor_util.make_tensor_proto(data, dtype=dtypes.qint32)
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_QINT32
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\000\000\025\000\000\000\026\000\000\000\027"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_QINT32
        tensor_shape { dim { size: 3 } }
        tensor_content: "\025\000\000\000\026\000\000\000\027\000\000\000"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(dtypes.qint32.as_numpy_dtype, a.dtype)
    self.assertAllEqual(np.array(data, dtype=a.dtype), a)

    t = tensor_util.make_tensor_proto(data, dtype=dtypes.quint8)
    self.assertProtoEquals(r"""
      dtype: DT_QUINT8
      tensor_shape { dim { size: 3 } }
      tensor_content: "\025\026\027"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(dtypes.quint8.as_numpy_dtype, a.dtype)
    self.assertAllEqual(np.array(data, dtype=a.dtype), a)

    t = tensor_util.make_tensor_proto(data, dtype=dtypes.qint8)
    self.assertProtoEquals(r"""
      dtype: DT_QINT8
      tensor_shape { dim { size: 3 } }
      tensor_content: "\025\026\027"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(dtypes.qint8.as_numpy_dtype, a.dtype)
    self.assertAllEqual(np.array(data, dtype=a.dtype), a)

    t = tensor_util.make_tensor_proto(data, dtype=dtypes.quint16)
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_QUINT16
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\025\000\026\000\027"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_QUINT16
        tensor_shape { dim { size: 3 } }
        tensor_content: "\025\000\026\000\027\000"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(dtypes.quint16.as_numpy_dtype, a.dtype)
    self.assertAllEqual(np.array(data, dtype=a.dtype), a)

    t = tensor_util.make_tensor_proto(data, dtype=dtypes.qint16)
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_QINT16
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\025\000\026\000\027"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_QINT16
        tensor_shape { dim { size: 3 } }
        tensor_content: "\025\000\026\000\027\000"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(dtypes.qint16.as_numpy_dtype, a.dtype)
    self.assertAllEqual(np.array(data, dtype=a.dtype), a)

  def testString(self):
    t = tensor_util.make_tensor_proto("foo")
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape {}
      string_val: "foo"
      """, t)
    a = tensor_util.MakeNdarray(t)
    # np.object was removed in NumPy 1.24; the builtin is the same dtype.
    self.assertEqual(object, a.dtype)
    self.assertEqual([b"foo"], a)

  def testStringWithImplicitRepeat(self):
    t = tensor_util.make_tensor_proto(["f", "g"], shape=[3, 4])
    a = tensor_util.MakeNdarray(t)
    self.assertAllEqual(
        np.array([[b"f", b"g", b"g", b"g"], [b"g", b"g", b"g", b"g"],
                  [b"g", b"g", b"g", b"g"]],
                 dtype=object), a)

  def testStringN(self):
    t = tensor_util.make_tensor_proto([b"foo", b"bar", b"baz"], shape=[1, 3])
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      string_val: "foo"
      string_val: "bar"
      string_val: "baz"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(object, a.dtype)
    self.assertAllEqual(np.array([[b"foo", b"bar", b"baz"]]), a)

  def testStringNpArray(self):
    t = tensor_util.make_tensor_proto(
        np.array([[b"a", b"ab"], [b"abc", b"abcd"]]))
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape { dim { size: 2 } dim { size: 2 } }
      string_val: "a"
      string_val: "ab"
      string_val: "abc"
      string_val: "abcd"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(object, a.dtype)
    self.assertAllEqual(np.array([[b"a", b"ab"], [b"abc", b"abcd"]]), a)

  def testArrayMethod(self):
    # Objects exposing __array__ are accepted like ndarrays.

    class Wrapper(object):

      def __array__(self):
        return np.array([b"foo", b"bar", b"baz"])

    t = tensor_util.make_tensor_proto(Wrapper(), shape=[1, 3])
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      string_val: "foo"
      string_val: "bar"
      string_val: "baz"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(object, a.dtype)
    self.assertAllEqual(np.array([[b"foo", b"bar", b"baz"]]), a)

  def testArrayInterface(self):
    # Objects exposing __array_interface__ are accepted like ndarrays.

    class Wrapper(object):

      @property
      def __array_interface__(self):
        return np.array([b"foo", b"bar", b"baz"]).__array_interface__

    t = tensor_util.make_tensor_proto(Wrapper(), shape=[1, 3])
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      string_val: "foo"
      string_val: "bar"
      string_val: "baz"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(object, a.dtype)
    self.assertAllEqual(np.array([[b"foo", b"bar", b"baz"]]), a)

  def testStringTuple(self):
    t = tensor_util.make_tensor_proto((b"a", b"ab", b"abc", b"abcd"))
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape { dim { size: 4 } }
      string_val: "a"
      string_val: "ab"
      string_val: "abc"
      string_val: "abcd"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(object, a.dtype)
    self.assertAllEqual(np.array((b"a", b"ab", b"abc", b"abcd")), a)

  def testStringNestedTuple(self):
    t = tensor_util.make_tensor_proto(((b"a", b"ab"), (b"abc", b"abcd")))
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape { dim { size: 2 } dim { size: 2 } }
      string_val: "a"
      string_val: "ab"
      string_val: "abc"
      string_val: "abcd"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(object, a.dtype)
    self.assertAllEqual(np.array(((b"a", b"ab"), (b"abc", b"abcd"))), a)

  def testComplex64(self):
    t = tensor_util.make_tensor_proto((1 + 2j), dtype=dtypes.complex64)
    self.assertProtoEquals("""
      dtype: DT_COMPLEX64
      tensor_shape {}
      scomplex_val: 1
      scomplex_val: 2
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.complex64, a.dtype)
    self.assertAllEqual(np.array(1 + 2j), a)

  def testComplex128(self):
    t = tensor_util.make_tensor_proto((1 + 2j), dtype=dtypes.complex128)
    self.assertProtoEquals("""
      dtype: DT_COMPLEX128
      tensor_shape {}
      dcomplex_val: 1
      dcomplex_val: 2
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.complex128, a.dtype)
    self.assertAllEqual(np.array(1 + 2j), a)

  def testComplexWithImplicitRepeat(self):
    for dtype, np_dtype in [(dtypes.complex64, np.complex64),
                            (dtypes.complex128, np.complex128)]:
      t = tensor_util.make_tensor_proto((1 + 1j), shape=[3, 4], dtype=dtype)
      a = tensor_util.MakeNdarray(t)
      self.assertAllClose(
          np.array(
              [[(1 + 1j), (1 + 1j), (1 + 1j), (1 + 1j)],
               [(1 + 1j), (1 + 1j), (1 + 1j), (1 + 1j)],
               [(1 + 1j), (1 + 1j), (1 + 1j), (1 + 1j)]],
              dtype=np_dtype),
          a)

  def testComplex64N(self):
    t = tensor_util.make_tensor_proto(
        [(1 + 2j), (3 + 4j), (5 + 6j)], shape=[1, 3], dtype=dtypes.complex64)
    self.assertProtoEquals("""
      dtype: DT_COMPLEX64
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      scomplex_val: 1
      scomplex_val: 2
      scomplex_val: 3
      scomplex_val: 4
      scomplex_val: 5
      scomplex_val: 6
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.complex64, a.dtype)
    self.assertAllEqual(np.array([[(1 + 2j), (3 + 4j), (5 + 6j)]]), a)

  def testComplex128N(self):
    t = tensor_util.make_tensor_proto(
        [(1 + 2j), (3 + 4j), (5 + 6j)], shape=[1, 3], dtype=dtypes.complex128)
    self.assertProtoEquals("""
      dtype: DT_COMPLEX128
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      dcomplex_val: 1
      dcomplex_val: 2
      dcomplex_val: 3
      dcomplex_val: 4
      dcomplex_val: 5
      dcomplex_val: 6
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.complex128, a.dtype)
    self.assertAllEqual(np.array([[(1 + 2j), (3 + 4j), (5 + 6j)]]), a)

  def testComplex64NpArray(self):
    t = tensor_util.make_tensor_proto(
        np.array([[(1 + 2j), (3 + 4j)], [(5 + 6j), (7 + 8j)]]),
        dtype=dtypes.complex64)
    # scomplex_val are real_0, imag_0, real_1, imag_1, ...
    self.assertProtoEquals("""
      dtype: DT_COMPLEX64
      tensor_shape { dim { size: 2 } dim { size: 2 } }
      scomplex_val: 1
      scomplex_val: 2
      scomplex_val: 3
      scomplex_val: 4
      scomplex_val: 5
      scomplex_val: 6
      scomplex_val: 7
      scomplex_val: 8
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.complex64, a.dtype)
    self.assertAllEqual(
        np.array([[(1 + 2j), (3 + 4j)], [(5 + 6j), (7 + 8j)]]), a)

  def testComplex128NpArray(self):
    t = tensor_util.make_tensor_proto(
        np.array([[(1 + 2j), (3 + 4j)], [(5 + 6j), (7 + 8j)]]),
        dtype=dtypes.complex128)
    # dcomplex_val are real_0, imag_0, real_1, imag_1, ...
    self.assertProtoEquals("""
      dtype: DT_COMPLEX128
      tensor_shape { dim { size: 2 } dim { size: 2 } }
      dcomplex_val: 1
      dcomplex_val: 2
      dcomplex_val: 3
      dcomplex_val: 4
      dcomplex_val: 5
      dcomplex_val: 6
      dcomplex_val: 7
      dcomplex_val: 8
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.complex128, a.dtype)
    self.assertAllEqual(
        np.array([[(1 + 2j), (3 + 4j)], [(5 + 6j), (7 + 8j)]]), a)

  def testNestedNumpyArrayWithoutDType(self):
    t = tensor_util.make_tensor_proto([10.0, 20.0, np.array(30.0)])
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testNestedNumpyArrayWithDType(self):
    t = tensor_util.make_tensor_proto([10.0, 20.0, np.array(30.0)],
                                      dtype=dtypes.float32)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testUnsupportedDTypes(self):
    with self.assertRaises(TypeError):
      tensor_util.make_tensor_proto(np.array([1]), 0)
    with self.assertRaises(TypeError):
      tensor_util.make_tensor_proto(3, dtype=dtypes.qint8)
    with self.assertRaises(TypeError):
      tensor_util.make_tensor_proto([3], dtype=dtypes.qint8)

    # Validate the helpful error message when trying to convert an
    # unconvertible list as strings.
    with self.assertRaisesRegex(TypeError, "Failed to convert object"):
      tensor_util.make_tensor_proto([tensor_shape.Dimension(1)])

  def testTensorShapeVerification(self):
    array = np.array([[1], [2]])
    correct_shape = (2, 1)
    incorrect_shape = (1, 2)
    tensor_util.make_tensor_proto(array, shape=correct_shape, verify_shape=True)
    with self.assertRaises(TypeError):
      tensor_util.make_tensor_proto(
          array, shape=incorrect_shape, verify_shape=True)

  def testShapeTooLarge(self):
    with self.assertRaises(ValueError):
      tensor_util.make_tensor_proto(np.array([1, 2]), shape=[1])

  def testLowRankSupported(self):
    t = tensor_util.make_tensor_proto(np.array(7))
    self.assertProtoEquals("""
      dtype: DT_INT64
      tensor_shape {}
      int64_val: 7
      """, t)

  def testShapeEquals(self):
    t = tensor_util.make_tensor_proto([10, 20, 30, 40], shape=[2, 2])
    self.assertTrue(tensor_util.ShapeEquals(t, [2, 2]))
    self.assertTrue(tensor_util.ShapeEquals(t, (2, 2)))
    self.assertTrue(
        tensor_util.ShapeEquals(t, tensor_shape.as_shape([2, 2]).as_proto()))
    self.assertFalse(tensor_util.ShapeEquals(t, [5, 3]))
    self.assertFalse(tensor_util.ShapeEquals(t, [1, 4]))
    self.assertFalse(tensor_util.ShapeEquals(t, [4]))
798
799
@test_util.run_all_in_graph_and_eager_modes
class IsTensorTest(test.TestCase):
  """Tests that is_tf_type accepts TF values and rejects their numpy forms."""

  def testConstantTensor(self):
    np_val = np.random.rand(3).astype(np.int32)
    tf_val = constant_op.constant(np_val)
    self.assertFalse(tensor_util.is_tf_type(np_val))
    self.assertTrue(tensor_util.is_tf_type(tf_val))

  def testRaggedTensor(self):
    rt = ragged_factory_ops.constant([[1, 2], [3]])
    rt_value = self.evaluate(rt)
    self.assertTrue(tensor_util.is_tf_type(rt))
    self.assertFalse(tensor_util.is_tf_type(rt_value))

  def testSparseTensor(self):
    st = sparse_tensor.SparseTensor([[1, 2]], [3], [10, 10])
    st_value = self.evaluate(st)
    self.assertTrue(tensor_util.is_tf_type(st))
    self.assertFalse(tensor_util.is_tf_type(st_value))

  def testIndexedSlices(self):
    x = indexed_slices.IndexedSlices(
        constant_op.constant([1, 2, 3]), constant_op.constant([10, 20, 30]))
    x_value = indexed_slices.IndexedSlicesValue(
        np.array([1, 2, 3]), np.array([10, 20, 30]), np.array([100]))
    self.assertTrue(tensor_util.is_tf_type(x))
    self.assertFalse(tensor_util.is_tf_type(x_value))

  def testVariable(self):
    v = variables.Variable([1, 2, 3])
    self.assertTrue(tensor_util.is_tf_type(v))
832
833
class ConstantValueTest(test.TestCase):
  """Tests for tensor_util.constant_value.

  constant_value attempts to statically fold a tensor to a numpy value;
  it returns None when the value cannot be determined at graph-build
  time.  All None checks use assertIsNone for consistency and clearer
  failure messages (the class previously mixed assertIs(None, x) with
  assertIsNone(x)).
  """

  def testConstant(self):
    """Values of constant ops are recovered exactly."""
    np_val = np.random.rand(3, 4, 7).astype(np.float32)
    tf_val = constant_op.constant(np_val)
    self.assertAllClose(np_val, tensor_util.constant_value(tf_val))

    # Zero-element arrays must also round-trip.
    np_val = np.random.rand(3, 0, 7).astype(np.float32)
    tf_val = constant_op.constant(np_val)
    self.assertAllClose(np_val, tensor_util.constant_value(tf_val))

  def testUnknown(self):
    """A stateful variable op has no statically known value."""
    with ops.Graph().as_default():
      tf_val = gen_state_ops.variable(
          shape=[3, 4, 7],
          dtype=dtypes.float32,
          name="tf_val",
          container="",
          shared_name="")
      self.assertIsNone(tensor_util.constant_value(tf_val))

  def testShape(self):
    """A Shape op on a statically shaped tensor folds to an int32 array."""
    np_val = np.array([1, 2, 3], dtype=np.int32)
    tf_val = array_ops.shape(constant_op.constant(0.0, shape=[1, 2, 3]))
    c_val = tensor_util.constant_value(tf_val)
    self.assertAllEqual(np_val, c_val)
    self.assertEqual(np.int32, c_val.dtype)

  def testFill(self):
    """Fill with a constant scalar folds to the filled array."""
    np_val = np.array([-1, -1, -1], dtype=np.float32)
    tf_val = array_ops.fill([3], constant_op.constant(-1.0))
    c_val = tensor_util.constant_value(tf_val)
    self.assertAllEqual(np_val, c_val)
    self.assertEqual(np.float32, c_val.dtype)

  def testSize(self):
    """Size of a statically shaped tensor folds to the element count."""
    tf_val = array_ops.size(constant_op.constant(0.0, shape=[1, 2, 3]))
    c_val = tensor_util.constant_value(tf_val)
    self.assertEqual(6, c_val)

  def testSizeOfScalar(self):
    """A scalar has size 1."""
    tf_val = array_ops.size(constant_op.constant(0.0))
    c_val = tensor_util.constant_value(tf_val)
    self.assertEqual(1, c_val)
    self.assertIn(type(c_val), [np.ndarray, np.int32])

  def testRank(self):
    """Rank ops fold to a scalar, with or without the rank optimization."""
    tf_val = array_ops.rank(constant_op.constant(0.0, shape=[1, 2, 3]))
    c_val = tensor_util.constant_value(tf_val)

    self.assertIn(type(c_val), [np.ndarray, np.int32])
    self.assertEqual((), c_val.shape)
    self.assertEqual(3, c_val)

    # Repeat test using array_ops.rank_internal to avoid the optimization that
    # happens in the rank function.
    tf_val = array_ops.rank_internal(
        constant_op.constant(
            0.0, shape=[1, 2, 3]), optimize=False)
    c_val = tensor_util.constant_value(tf_val)

    self.assertIn(type(c_val), [np.ndarray, np.int32])
    self.assertEqual((), c_val.shape)
    self.assertEqual(3, c_val)
    self.assertEqual([3], c_val)

  def testCast(self):
    """Cast ops fold by casting the folded input, even for empty arrays."""
    np_val = np.random.rand(3, 4, 7).astype(np.float32)
    tf_val = math_ops.cast(constant_op.constant(np_val), dtypes.float64)
    c_val = tensor_util.constant_value(tf_val)
    self.assertAllClose(np_val.astype(np.float64), c_val)

    np_val = np.random.rand(3, 0, 7).astype(np.float32)
    tf_val = math_ops.cast(constant_op.constant(np_val), dtypes.float64)
    c_val = tensor_util.constant_value(tf_val)
    self.assertAllClose(np_val.astype(np.float64), c_val)

  def testConcat(self):
    """Concat folds when all inputs and the axis are known, else None."""
    np_val = np.random.rand(3, 4, 7).astype(np.float32)
    tf_val = array_ops.concat(
        [np_val[0:1, :, :], np_val[1:2, :, :], np_val[2:3, :, :]], 0)
    c_val = tensor_util.constant_value(tf_val)
    self.assertAllClose(np_val, c_val)

    # This test needs a placeholder which means we need to construct a graph.
    with ops.Graph().as_default():
      # Unknown concat axis: cannot fold.
      tf_val = array_ops.concat(
          [np_val[0, :, :], np_val[1, :, :], np_val[2, :, :]],
          array_ops.placeholder(dtypes.int32))
      c_val = tensor_util.constant_value(tf_val)
      self.assertIsNone(c_val)

      # Unknown input: cannot fold.
      tf_val = array_ops.concat([
          np_val[0, :, :],
          array_ops.placeholder(dtypes.float32), np_val[2, :, :]
      ], 1)
      c_val = tensor_util.constant_value(tf_val)
      self.assertIsNone(c_val)

  def testPack_Axis0(self):
    """Stack along axis 0 folds only when every input is known."""
    inputs = [np.random.rand(4, 7) for _ in range(3)]
    np_val = np.array(inputs)
    tf_val = array_ops.stack(inputs)
    c_val = tensor_util.constant_value(tf_val)
    self.assertAllClose(np_val, c_val)

    # This test needs a placeholder which means we need to construct a graph.
    with ops.Graph().as_default():
      tf_val = array_ops.stack(
          [inputs[0],
           array_ops.placeholder(dtypes.float32), inputs[2]])
      c_val = tensor_util.constant_value(tf_val)
      self.assertIsNone(c_val)

  def testPack_Axis1(self):
    """Stack along axis 1 is never folded."""
    # This test needs a placeholder which means we need to construct a graph.
    with ops.Graph().as_default():
      inputs = [np.random.rand(4, 7) for _ in range(3)]
      tf_val = array_ops.stack(inputs, axis=1)
      c_val = tensor_util.constant_value(tf_val)
      self.assertIsNone(c_val)

      tf_val = array_ops.stack(
          [inputs[0],
           array_ops.placeholder(dtypes.float32), inputs[2]], axis=1)
      c_val = tensor_util.constant_value(tf_val)
      self.assertIsNone(c_val)

  def testPack_Partial_Axis0(self):
    """partial=True returns known entries and None for unknown ones."""
    input_ = np.random.rand(4, 7)
    # This test needs a placeholder which means we need to construct a graph.
    with ops.Graph().as_default():
      tf_val = array_ops.stack([input_, array_ops.placeholder(dtypes.float32)])
      c_val = tensor_util.constant_value(tf_val, partial=True)
      self.assertAllClose(input_, c_val[0])
      self.assertIsNone(c_val[1])

  def testPack_Partial_Axis1(self):
    """partial=True does not help for stacking on axis 1."""
    input_ = np.random.rand(4, 7)
    # This test needs a placeholder which means we need to construct a graph.
    with ops.Graph().as_default():
      tf_val = array_ops.stack(
          [input_, array_ops.placeholder(dtypes.float32)], axis=1)
      c_val = tensor_util.constant_value(tf_val, partial=True)
      self.assertIsNone(c_val)

  def testUnpack_Axis0(self):
    """Unstacked slices of a known tensor each fold."""
    inputs = np.random.rand(3, 4, 7)
    tf_vals = array_ops.unstack(inputs)
    c_vals = [tensor_util.constant_value(x) for x in tf_vals]
    self.assertAllClose(inputs, c_vals)

  def testUnpack_Partial_Axis0(self):
    """partial=True recovers the known element of a partly known stack."""
    input_ = np.random.rand(4, 7)
    # This test needs a placeholder which means we need to construct a graph.
    with ops.Graph().as_default():
      packed = array_ops.stack([input_, array_ops.placeholder(dtypes.float32)])
      tf_vals = array_ops.unstack(packed)
      c_vals = [tensor_util.constant_value(x, partial=True) for x in tf_vals]
      self.assertAllClose(input_, c_vals[0])
      self.assertIsNone(c_vals[1])

  def testSplit_Axis0(self):
    """Split of a known tensor folds to the numpy split."""
    inputs = np.random.rand(6, 5, 7)
    tf_vals = array_ops.split(inputs, 3)
    c_vals = [tensor_util.constant_value(x) for x in tf_vals]
    self.assertAllClose(np.split(inputs, 3), c_vals)

  def testSplit_Partial_Axis0(self):
    """partial=True marks unknown split pieces as None."""
    input_ = np.random.rand(4, 7)
    # This test needs a placeholder which means we need to construct a graph.
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(dtypes.float32, shape=(4, 7))
      # it'd be better to use concat here, but concat doesn't support partial
      packed = array_ops.stack([input_, placeholder])
      tf_vals = array_ops.split(packed, 2)
      c_vals = [tensor_util.constant_value(x, partial=True) for x in tf_vals]
      self.assertAllClose(input_, c_vals[0][0])
      self.assertIsNone(c_vals[1][0])

  def testEqual(self):
    """equal folds for scalars and for broadcasting shaped inputs."""
    # Scalar inputs.
    tf_val = math_ops.equal(constant_op.constant(1), constant_op.constant(1))
    self.assertEqual(tensor_util.constant_value(tf_val), True)

    tf_val = math_ops.equal(constant_op.constant(1), constant_op.constant(0))
    self.assertEqual(tensor_util.constant_value(tf_val), False)

    # Shaped inputs with broadcast semantics.
    tf_val = math_ops.equal(constant_op.constant([[0, 1]]),
                            constant_op.constant([[0], [1]]))
    c_val = tensor_util.constant_value(tf_val)
    self.assertAllEqual(c_val, [[True, False], [False, True]])

  def testNotEqual(self):
    """not_equal folds for scalars and for broadcasting shaped inputs."""
    # Scalar inputs.
    tf_val = math_ops.not_equal(constant_op.constant(1),
                                constant_op.constant(1))
    self.assertEqual(tensor_util.constant_value(tf_val), False)

    tf_val = math_ops.not_equal(constant_op.constant(1),
                                constant_op.constant(0))
    self.assertEqual(tensor_util.constant_value(tf_val), True)

    # Shaped inputs with broadcast semantics.
    tf_val = math_ops.not_equal(constant_op.constant([[0, 1]]),
                                constant_op.constant([[0], [1]]))
    c_val = tensor_util.constant_value(tf_val)
    self.assertAllEqual(c_val, [[False, True], [True, False]])

  def testStopGradient(self):
    """StopGradient is value-transparent for folding."""
    input_ = np.random.rand(4, 7)
    tf_val = array_ops.stop_gradient(input_)
    c_val = tensor_util.constant_value(tf_val)
    self.assertAllEqual(input_, c_val)

  def testIdentity(self):
    """Identity is value-transparent for folding."""
    input_ = np.random.rand(4, 7)
    tf_val = array_ops.identity(input_)
    c_val = tensor_util.constant_value(tf_val)
    self.assertAllEqual(input_, c_val)

  def testLiteral(self):
    """Non-tensor Python literals are returned unchanged (same object)."""
    x = "hi"
    self.assertIs(x, tensor_util.constant_value(x))

  def testNumpyNdarray(self):
    """numpy arrays are returned unchanged (same object, no copy)."""
    np_val = np.random.rand(3, 4, 7).astype(np.float32)
    self.assertIs(np_val, tensor_util.constant_value(np_val))

  def testVariable(self):
    """Variables never have a statically known value."""
    var = variables.Variable(1.0, name="variable_node")
    self.assertIsNone(tensor_util.constant_value(var))

  def testVariableV1(self):
    """V1 variables never have a statically known value."""
    var = variables.VariableV1(1.0, name="variable_node")
    self.assertIsNone(tensor_util.constant_value(var))
1071
1072
class ConstantValueAsShapeTest(test.TestCase):
  """Tests for tensor_util.constant_value_as_shape.

  constant_value_as_shape converts a shape-describing int tensor into a
  TensorShape, mapping statically unknown entries (and -1) to None.
  """

  @test_util.run_in_graph_and_eager_modes
  def testConstant(self):
    """A constant int vector becomes a fully defined shape."""
    np_val = np.random.rand(3).astype(np.int32)
    tf_val = constant_op.constant(np_val)
    self.assertEqual(
        tensor_shape.TensorShape(np_val),
        tensor_util.constant_value_as_shape(tf_val))

    # An empty vector maps to the rank-0 shape.
    tf_val = constant_op.constant([], dtype=dtypes.int32)
    self.assertEqual(
        tensor_shape.TensorShape([]),
        tensor_util.constant_value_as_shape(tf_val))

  @test_util.run_in_graph_and_eager_modes
  def testCast(self):
    """Casting the shape tensor's dtype does not lose the shape."""
    tf_val = math_ops.cast(
        array_ops.shape(constant_op.constant(0.0, shape=[1, 2, 3])),
        dtypes.int64)
    c_val = tensor_util.constant_value_as_shape(tf_val)
    self.assertEqual(tensor_shape.TensorShape([1, 2, 3]), c_val)

  @test_util.run_in_graph_and_eager_modes
  def testCastWithUnknown(self):
    """-1 entries survive a cast and map to unknown (None) dimensions."""
    tf_val = math_ops.cast(constant_op.constant([-1, 1, -1]), dtypes.int64)
    c_val = tensor_util.constant_value_as_shape(tf_val)
    self.assertEqual([None, 1, None], c_val.as_list())

  @test_util.run_in_graph_and_eager_modes
  def testShape(self):
    """A Shape op's output converts back to the original shape."""
    tf_val = array_ops.shape(constant_op.constant(0.0, shape=[1, 2, 3]))
    c_val = tensor_util.constant_value_as_shape(tf_val)
    self.assertEqual(tensor_shape.TensorShape([1, 2, 3]), c_val)

  @test_util.run_in_graph_and_eager_modes
  def testMinusOneBecomesNone(self):
    """-1 dimension markers map to None in the resulting shape."""
    tf_val = constant_op.constant([-1, 1, -1], shape=[3])
    c_val = tensor_util.constant_value_as_shape(tf_val)
    self.assertEqual([None, 1, None], c_val.as_list())

  def testPack(self):
    """Stacked scalars keep per-element knowledge; placeholders become None."""
    # This test needs a placeholder which means we need to construct a graph.
    with ops.Graph().as_default():
      tf_val = array_ops.stack(
          [constant_op.constant(16), 37,
           array_ops.placeholder(dtypes.int32)])
      c_val = tensor_util.constant_value_as_shape(tf_val)
      self.assertEqual([16, 37, None], c_val.as_list())

  def testConcat(self):
    """Concatenated shape pieces merge known and unknown dimensions."""
    # This test needs a placeholder which means we need to construct a graph.
    with ops.Graph().as_default():
      tf_val = array_ops.concat(
          [[16, 37], array_ops.placeholder(dtypes.int32, shape=(2,))], 0)
      c_val = tensor_util.constant_value_as_shape(tf_val)
      self.assertEqual([16, 37, None, None], c_val.as_list())

      tf_val = array_ops.concat(
          [[16, 37],
           array_ops.placeholder(dtypes.int32, shape=(1,)), [48]], 0)
      c_val = tensor_util.constant_value_as_shape(tf_val)
      self.assertEqual([16, 37, None, 48], c_val.as_list())

  def testSlice(self):
    """Slicing a shape tensor slices the inferred shape accordingly."""
    # A length-4 vector of unknown values slices to unknown dimensions.
    # This test needs a placeholder which means we need to construct a graph.
    with ops.Graph().as_default():
      tf_val = array_ops.placeholder(dtypes.int32, shape=(4,))[0:2]
      c_val = tensor_util.constant_value_as_shape(tf_val)
      self.assertEqual([None, None], c_val.as_list())

    # begin:end
    tf_val = constant_op.constant([10, 20, 30])[1:3]
    c_val = tensor_util.constant_value_as_shape(tf_val)
    self.assertEqual([20, 30], c_val.as_list())

    # begin:end:stride
    tf_val = array_ops.strided_slice(
        constant_op.constant([10, 20, 30]), [1], [3], strides=[2])
    c_val = tensor_util.constant_value_as_shape(tf_val)
    self.assertEqual([20], c_val.as_list())

    # The remaining cases slice a partially known shape vector:
    # [1, 2, 16, 37, None, 48]
    # This test needs a placeholder which means we need to construct a graph.
    with ops.Graph().as_default():
      tf_val_orig = array_ops.concat(
          [[1, 2, 16, 37],
           array_ops.placeholder(dtypes.int32, shape=(1,)), [48]], 0)

      # begin: no end
      tf_val = tf_val_orig[2:]
      c_val = tensor_util.constant_value_as_shape(tf_val)
      self.assertEqual([16, 37, None, 48], c_val.as_list())

      # begin::negative slice
      tf_val = tf_val_orig[2::-1]
      c_val = tensor_util.constant_value_as_shape(tf_val)
      self.assertEqual([16, 2, 1], c_val.as_list())

      # :end:negative slice
      tf_val = tf_val_orig[:1:-2]
      c_val = tensor_util.constant_value_as_shape(tf_val)
      self.assertEqual([48, 37], c_val.as_list())

      # begin:end:negative slice
      tf_val = tf_val_orig[3:1:-1]
      c_val = tensor_util.constant_value_as_shape(tf_val)
      self.assertEqual([37, 16], c_val.as_list())

      # begin:negative end:slice
      tf_val = tf_val_orig[1:-3:1]
      c_val = tensor_util.constant_value_as_shape(tf_val)
      self.assertEqual([2, 16], c_val.as_list())

      # negative begin::slice
      tf_val = tf_val_orig[-3::1]
      c_val = tensor_util.constant_value_as_shape(tf_val)
      self.assertEqual([37, None, 48], c_val.as_list())

      # negative begin::negative slice
      tf_val = tf_val_orig[-3::-1]
      c_val = tensor_util.constant_value_as_shape(tf_val)
      self.assertEqual([37, 16, 2, 1], c_val.as_list())

      # negative begin:negative end:negative slice
      tf_val = tf_val_orig[-3:-5:-1]
      c_val = tensor_util.constant_value_as_shape(tf_val)
      self.assertEqual([37, 16], c_val.as_list())

      # Do not support shape inference for additional arguments
      tf_val = constant_op.constant([10, 20, 30])[...]
      c_val = tensor_util.constant_value_as_shape(tf_val)
      self.assertEqual([None, None, None], c_val.as_list())

      # Do not support shape inference for tensor slices.
      tf_val = constant_op.constant(
          [10, 20, 30])[array_ops.placeholder(dtypes.int32, shape=()):]
      c_val = tensor_util.constant_value_as_shape(tf_val)
      self.assertEqual(tensor_shape.unknown_shape(), c_val)

      # Do not support shape inference for higher rank
      with self.assertRaises(ValueError):
        tf_val = constant_op.constant([[10], [20], [30]])[:, 0:]
        c_val = tensor_util.constant_value_as_shape(tf_val)
1217
1218
class MaybeSetStaticShapeTest(test.TestCase):
  """Checks that disabling maybe_set_static_shape leaves graphs op-equivalent."""

  @contextlib.contextmanager
  def disableSetStaticShape(self):
    """Turns off maybe_set_static_shape for the duration of the block."""
    previous = tensor_util._ENABLE_MAYBE_SET_STATIC_SHAPE
    tensor_util._ENABLE_MAYBE_SET_STATIC_SHAPE = False
    try:
      yield
    finally:
      # Restore the flag even if the body raised.
      tensor_util._ENABLE_MAYBE_SET_STATIC_SHAPE = previous

  def testMaybeSetStaticShape(self):
    """Reshape to a constant shape traces the same ops either way."""
    shape = constant_op.constant([2, 5], dtype=dtypes.int32)

    def reshape():
      return array_ops.reshape(array_ops.zeros([10]), shape)

    # Placeholders require graph mode, so build the graphs explicitly.
    with ops.Graph().as_default():
      with self.disableSetStaticShape():
        without_propagation = func_graph.func_graph_from_py_func(
            "without_shape_propagation", reshape, [], {})
      with_propagation = func_graph.func_graph_from_py_func(
          "with_shape_propagation", reshape, [], {})
      self.assertCountEqual(
          [op.type for op in without_propagation.get_operations()],
          [op.type for op in with_propagation.get_operations()])

  def testMaybeSetStaticShapeScalarShape(self):
    """Reshape of an unknown-shape input traces the same ops either way."""

    def reshape():
      placeholder = array_ops.placeholder(dtypes.float32)
      return array_ops.reshape(placeholder, [-1])

    with self.disableSetStaticShape():
      without_propagation = func_graph.func_graph_from_py_func(
          "without_shape_propagation", reshape, [], {})
    with_propagation = func_graph.func_graph_from_py_func(
        "with_shape_propagation", reshape, [], {})
    self.assertCountEqual(
        [op.type for op in without_propagation.get_operations()],
        [op.type for op in with_propagation.get_operations()])
1262
1263
class ShapeTensorTest(test_util.TensorFlowTestCase):
  """Tests for tensor_util.shape_tensor."""

  @test_util.run_in_graph_and_eager_modes
  def testConversion(self):
    """Fully known TensorShape objects must convert to Tensors."""
    known_shape = tensor_shape.TensorShape([1, tensor_shape.Dimension(2)])
    as_tensor = tensor_util.shape_tensor(known_shape)
    self.assertAllEqual((1, 2), as_tensor)
1272
1273
# Standard TensorFlow test entry point: runs all test cases in this module.
if __name__ == "__main__":
  test.main()
1276