# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for tensor_util."""

import contextlib
import sys

from absl.testing import parameterized
import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class TensorUtilTest(test.TestCase, parameterized.TestCase):
  """Round-trip tests for `make_tensor_proto` / `MakeNdarray`.

  The common pattern: build a TensorProto from a Python value or numpy
  array, compare it against a golden text-format proto, then convert it
  back with `MakeNdarray` and check dtype and contents.  Protos that pack
  values into `tensor_content` store raw host-order bytes, so those golden
  strings branch on `sys.byteorder`.
  """

  def testFloat(self):
    # A bare Python float becomes a DT_FLOAT scalar (empty tensor_shape).
    value = 10.0
    t = tensor_util.make_tensor_proto(value)
    self.assertProtoEquals("""
      dtype: DT_FLOAT
      tensor_shape {}
      float_val: %.1f
      """ % value, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array(value, dtype=np.float32), a)

  def testFloatN(self):
    # A list of floats is packed into tensor_content as raw host-order
    # float32 bytes; the golden bytes differ by endianness.
    t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0])
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "A \000\000A\240\000\000A\360\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\000 A\000\000\240A\000\000\360A"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testFloatTyped(self):
    # Same as testFloatN but with the dtype given explicitly.
    t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], dtype=dtypes.float32)
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "A \000\000A\240\000\000A\360\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\000 A\000\000\240A\000\000\360A"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testFloatTypeCoerce(self):
    # Integer inputs are coerced to the requested float dtype.
    t = tensor_util.make_tensor_proto([10, 20, 30], dtype=dtypes.float32)
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "A \000\000A\240\000\000A\360\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\000 A\000\000\240A\000\000\360A"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testFloatTypeCoerceNdarray(self):
    # An int ndarray is likewise coerced to the requested float dtype.
    arr = np.asarray([10, 20, 30], dtype="int")
    t = tensor_util.make_tensor_proto(arr, dtype=dtypes.float32)
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "A \000\000A\240\000\000A\360\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\000 A\000\000\240A\000\000\360A"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testFloatSizes(self):
    # An explicit shape reshapes the flat value list ([1, 3] here).
    t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], shape=[1, 3])
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 1 } dim { size: 3 } }
        tensor_content: "A \000\000A\240\000\000A\360\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 1 } dim { size: 3 } }
        tensor_content: "\000\000 A\000\000\240A\000\000\360A"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([[10.0, 20.0, 30.0]], dtype=np.float32), a)

  def testFloatSizes2(self):
    # Same values, transposed shape [3, 1]; byte content is unchanged.
    t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], shape=[3, 1])
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } dim { size: 1 } }
        tensor_content: "A \000\000A\240\000\000A\360\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } dim { size: 1 } }
        tensor_content: "\000\000 A\000\000\240A\000\000\360A"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([[10.0], [20.0], [30.0]], dtype=np.float32), a)

  def testFloatSizesLessValues(self):
    # A scalar with a larger shape is stored as a single repeated value.
    t = tensor_util.make_tensor_proto(10.0, shape=[1, 3])
    self.assertProtoEquals("""
      dtype: DT_FLOAT
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      float_val: 10.0
      """, t)
    # No conversion to Ndarray for this one: not enough values.

  def testFloatNpArrayFloat64(self):
    # float64 input maps to DT_DOUBLE (8-byte packed values).
    t = tensor_util.make_tensor_proto(
        np.array([[10.0, 20.0, 30.0]], dtype=np.float64))
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_DOUBLE
        tensor_shape { dim { size: 1 } dim { size: 3 } }
        tensor_content: "@$\000\000\000\000\000\000@4\000\000\000\000\000\000@>\000\000\000\000\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_DOUBLE
        tensor_shape { dim { size: 1 } dim { size: 3 } }
        tensor_content: "\000\000\000\000\000\000$@\000\000\000\000\000\0004@\000\000\000\000\000\000>@"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float64, a.dtype)
    self.assertAllClose(
        np.array([[10.0, 20.0, 30.0]], dtype=np.float64),
        tensor_util.MakeNdarray(t))

  def testFloatTypesWithImplicitRepeat(self):
    # When fewer values than the shape requires are given, the last value
    # is repeated to fill the tensor (here a single 10.0 fills [3, 4]).
    for dtype, nptype in [(dtypes.float32, np.float32),
                          (dtypes.float64, np.float64)]:
      t = tensor_util.make_tensor_proto([10.0], shape=[3, 4], dtype=dtype)
      a = tensor_util.MakeNdarray(t)
      self.assertAllClose(
          np.array(
              [[10.0, 10.0, 10.0, 10.0],
               [10.0, 10.0, 10.0, 10.0],
               [10.0, 10.0, 10.0, 10.0]],
              dtype=nptype),
          a)

  def testFloatMutateArray(self):
    # Mutating the ndarray returned by MakeNdarray must not alter the
    # proto it was decoded from (the golden below is the original bytes).
    t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], dtype=dtypes.float32)
    a = tensor_util.MakeNdarray(t)
    a[0] = 5.0
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([5.0, 20.0, 30.0], dtype=np.float32), a)
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "A \000\000A\240\000\000A\360\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\000 A\000\000\240A\000\000\360A"
        """, t)

  def testHalf(self):
    t = tensor_util.make_tensor_proto(np.array([10.0, 20.0], dtype=np.float16))
    # NOTE: this golden is a non-raw string, so the \000 escapes below are
    # literal NUL bytes rather than backslash sequences.
    self.assertProtoEquals(
        """
      dtype: DT_HALF
      tensor_shape { dim { size: 2 } }
      tensor_content: "\000I\000M"
      """, t)

    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float16, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0], dtype=np.float16), a)

  def testBfloat16(self):
    # bfloat16 values are stored in the half_val field, one int per value.
    test_type = dtypes.bfloat16.as_numpy_dtype
    t = tensor_util.make_tensor_proto(np.array([10.0, 20.0], dtype=test_type))
    # 10.0: 16672 = 010000010(130) 0100000: (1+0/2+1/4) * 2^(130-127)
    # 20.0: 16800 = 010000011(131) 0100000: (1+0/2+1/4) * 2^(131-127)
    self.assertProtoEquals("""
      dtype: DT_BFLOAT16
      tensor_shape {
        dim {
          size: 2
        }
      }
      half_val: 16672
      half_val: 16800
      """, t)

    a = tensor_util.MakeNdarray(t)
    self.assertEqual(test_type, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0], dtype=test_type), a)

  def testInt(self):
    # A bare Python int defaults to DT_INT32.
    t = tensor_util.make_tensor_proto(10)
    self.assertProtoEquals("""
      dtype: DT_INT32
      tensor_shape {}
      int_val: 10
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.int32, a.dtype)
    self.assertAllClose(np.array(10, dtype=np.int32), a)

  def testLargeInt(self):
    # An int too large for int32 is promoted to DT_INT64.
    value = np.iinfo(np.int64).max
    t = tensor_util.make_tensor_proto(value)
    self.assertProtoEquals("""
      dtype: DT_INT64
      tensor_shape {}
      int64_val: %d
      """ % value, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.int64, a.dtype)
    self.assertAllClose(np.array(value, dtype=np.int64), a)

  def testLargeNegativeInt(self):
    # We don't use the min np.int64 value here
    # because it breaks np.abs().
    #
    # np.iinfo(np.int64).min = -9223372036854775808
    # np.iinfo(np.int64).max = 9223372036854775807
    # np.abs(-9223372036854775808) = -9223372036854775808
    value = np.iinfo(np.int64).min + 1
    t = tensor_util.make_tensor_proto(value)
    self.assertProtoEquals("""
      dtype: DT_INT64
      tensor_shape {}
      int64_val: %d
      """ % value, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.int64, a.dtype)
    self.assertAllClose(np.array(value, dtype=np.int64), a)

  def testIntNDefaultType(self):
    # Multiple ints are packed into tensor_content as host-order int32.
    t = tensor_util.make_tensor_proto([10, 20, 30, 40], shape=[2, 2])
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_INT32
        tensor_shape { dim { size: 2 } dim { size: 2 } }
        tensor_content: "\000\000\000\n\000\000\000\024\000\000\000\036\000\000\000("
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_INT32
        tensor_shape { dim { size: 2 } dim { size: 2 } }
        tensor_content: "\n\000\000\000\024\000\000\000\036\000\000\000(\000\000\000"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.int32, a.dtype)
    self.assertAllClose(np.array([[10, 20], [30, 40]], dtype=np.int32), a)

  @parameterized.named_parameters(
      ("_int8", dtypes.int8, np.int8), ("_int16", dtypes.int16, np.int16),
      ("_int32", dtypes.int32, np.int32), ("_int64", dtypes.int64, np.int64),
      ("_uint8", dtypes.uint8, np.uint8), ("_uint16", dtypes.uint16, np.uint16),
      ("_uint32", dtypes.uint32, np.uint32),
      ("_uint64", dtypes.uint64, np.uint64))
  def testIntTypes(self, dtype, nptype):
    """Round-trips every signed/unsigned integer width, list and ndarray."""
    # Test with array.
    t = tensor_util.make_tensor_proto([10, 20, 30], dtype=dtype)
    self.assertEqual(dtype, t.dtype)
    self.assertProtoEquals("dim { size: 3 }", t.tensor_shape)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(nptype, a.dtype)
    self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)
    # Test with ndarray.
    t = tensor_util.make_tensor_proto(np.array([10, 20, 30], dtype=nptype))
    self.assertEqual(dtype, t.dtype)
    self.assertProtoEquals("dim { size: 3 }", t.tensor_shape)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(nptype, a.dtype)
    self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)

  @parameterized.named_parameters(
      ("_int8", dtypes.int8, np.int8), ("_int16", dtypes.int16, np.int16),
      ("_int32", dtypes.int32, np.int32), ("_int64", dtypes.int64, np.int64),
      ("_uint8", dtypes.uint8, np.uint8), ("_uint16", dtypes.uint16, np.uint16),
      ("_uint32", dtypes.uint32, np.uint32),
      ("_uint64", dtypes.uint64, np.uint64))
  def testIntTypesWithImplicitRepeat(self, dtype, nptype):
    # The last supplied value (12) repeats to fill the remaining elements.
    self.assertAllEqual(
        np.array([[10, 11, 12, 12], [12, 12, 12, 12], [12, 12, 12, 12]],
                 dtype=nptype),
        tensor_util.MakeNdarray(
            tensor_util.make_tensor_proto([10, 11, 12],
                                          shape=[3, 4],
                                          dtype=dtype)))

  def testIntMixedWithDimension(self):
    # Github issue: 11974
    # A tensor_shape.Dimension mixed into an int list must convert cleanly.
    dtype = dtypes.int32
    nptype = np.int32
    t = tensor_util.make_tensor_proto(
        [10, tensor_shape.Dimension(20), 30], dtype=dtype)
    self.assertEqual(dtype, t.dtype)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(nptype, a.dtype)
    self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)

  @parameterized.named_parameters(
      ("_int64", dtypes.int64, np.int64, "DT_INT64", "int64_val"),
      ("_uint64", dtypes.uint64, np.uint64, "DT_UINT64", "uint64_val"))
  def testLong(self, dtype, nptype, proto_dtype, proto_value_name):
    # Scalar 64-bit values land in int64_val / uint64_val respectively.
    t = tensor_util.make_tensor_proto(10, dtype=dtype)
    self.assertProtoEquals(
        """
      dtype: %s
      tensor_shape {}
      %s: 10
    """ % (proto_dtype, proto_value_name), t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(nptype, a.dtype)
    self.assertAllClose(np.array(10, dtype=nptype), a)

  @parameterized.named_parameters(
      ("_int64", dtypes.int64, np.int64, "DT_INT64"),
      ("_uint64", dtypes.uint64, np.uint64, "DT_UINT64"))
  def testLongN(self, dtype, nptype, proto_dtype):
    # 64-bit lists pack into tensor_content (8 bytes per element).
    t = tensor_util.make_tensor_proto([10, 20, 30], shape=[1, 3], dtype=dtype)
    if sys.byteorder == "big":
      # pylint: disable=line-too-long
      self.assertProtoEquals(
          r"""
        dtype: %s
        tensor_shape { dim { size: 1 } dim { size: 3 } }
        tensor_content: "\000\000\000\000\000\000\000\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036"
        """ % proto_dtype, t)
      # pylint: enable=line-too-long
    else:
      # pylint: disable=line-too-long
      self.assertProtoEquals(
          r"""
        dtype: %s
        tensor_shape { dim { size: 1 } dim { size: 3 } }
        tensor_content: "\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
        """ % proto_dtype, t)
      # pylint: enable=line-too-long
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(nptype, a.dtype)
    self.assertAllClose(np.array([[10, 20, 30]], dtype=nptype), a)

  @parameterized.named_parameters(("_int64", np.int64, "DT_INT64"),
                                  ("_uint64", np.uint64, "DT_UINT64"))
  def testLongNpArray(self, nptype, proto_dtype):
    # Same as testLongN but starting from an ndarray (shape inferred).
    t = tensor_util.make_tensor_proto(np.array([10, 20, 30], dtype=nptype))
    if sys.byteorder == "big":
      # pylint: disable=line-too-long
      self.assertProtoEquals(
          r"""
        dtype: %s
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\000\000\000\000\000\000\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036"
        """ % proto_dtype, t)
      # pylint: enable=line-too-long
    else:
      # pylint: disable=line-too-long
      self.assertProtoEquals(
          r"""
        dtype: %s
        tensor_shape { dim { size: 3 } }
        tensor_content: "\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
        """ % proto_dtype, t)
      # pylint: enable=line-too-long
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(nptype, a.dtype)
    self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)

  def testQuantizedTypes(self):
    """Covers qint32, quint8, qint8, quint16, qint16 round-trips."""
    # Test with array.
    data = [(21,), (22,), (23,)]

    t = tensor_util.make_tensor_proto(data, dtype=dtypes.qint32)
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_QINT32
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\000\000\025\000\000\000\026\000\000\000\027"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_QINT32
        tensor_shape { dim { size: 3 } }
        tensor_content: "\025\000\000\000\026\000\000\000\027\000\000\000"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(dtypes.qint32.as_numpy_dtype, a.dtype)
    self.assertAllEqual(np.array(data, dtype=a.dtype), a)

    # Single-byte quantized types have no endianness branch.
    t = tensor_util.make_tensor_proto(data, dtype=dtypes.quint8)
    self.assertProtoEquals(r"""
      dtype: DT_QUINT8
      tensor_shape { dim { size: 3 } }
      tensor_content: "\025\026\027"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(dtypes.quint8.as_numpy_dtype, a.dtype)
    self.assertAllEqual(np.array(data, dtype=a.dtype), a)

    t = tensor_util.make_tensor_proto(data, dtype=dtypes.qint8)
    self.assertProtoEquals(r"""
      dtype: DT_QINT8
      tensor_shape { dim { size: 3 } }
      tensor_content: "\025\026\027"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(dtypes.qint8.as_numpy_dtype, a.dtype)
    self.assertAllEqual(np.array(data, dtype=a.dtype), a)

    t = tensor_util.make_tensor_proto(data, dtype=dtypes.quint16)
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_QUINT16
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\025\000\026\000\027"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_QUINT16
        tensor_shape { dim { size: 3 } }
        tensor_content: "\025\000\026\000\027\000"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(dtypes.quint16.as_numpy_dtype, a.dtype)
    self.assertAllEqual(np.array(data, dtype=a.dtype), a)

    t = tensor_util.make_tensor_proto(data, dtype=dtypes.qint16)
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_QINT16
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\025\000\026\000\027"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_QINT16
        tensor_shape { dim { size: 3 } }
        tensor_content: "\025\000\026\000\027\000"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(dtypes.qint16.as_numpy_dtype, a.dtype)
    self.assertAllEqual(np.array(data, dtype=a.dtype), a)

  def testString(self):
    # Strings decode to dtype object ndarrays of bytes.
    t = tensor_util.make_tensor_proto("foo")
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape {}
      string_val: "foo"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.object_, a.dtype)
    self.assertEqual([b"foo"], a)

  def testStringWithImplicitRepeat(self):
    # Implicit repeat applies to strings too: "g" fills the remainder.
    t = tensor_util.make_tensor_proto(["f", "g"], shape=[3, 4])
    a = tensor_util.MakeNdarray(t)
    self.assertAllEqual(
        np.array([[b"f", b"g", b"g", b"g"], [b"g", b"g", b"g", b"g"],
                  [b"g", b"g", b"g", b"g"]],
                 dtype=np.object_), a)

  def testStringN(self):
    t = tensor_util.make_tensor_proto([b"foo", b"bar", b"baz"], shape=[1, 3])
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      string_val: "foo"
      string_val: "bar"
      string_val: "baz"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.object_, a.dtype)
    self.assertAllEqual(np.array([[b"foo", b"bar", b"baz"]]), a)

  def testStringNpArray(self):
    # Variable-length strings in an ndarray; one string_val per element.
    t = tensor_util.make_tensor_proto(
        np.array([[b"a", b"ab"], [b"abc", b"abcd"]]))
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape { dim { size: 2 } dim { size: 2 } }
      string_val: "a"
      string_val: "ab"
      string_val: "abc"
      string_val: "abcd"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.object_, a.dtype)
    self.assertAllEqual(np.array([[b"a", b"ab"], [b"abc", b"abcd"]]), a)

  def testArrayMethod(self):
    # Objects exposing __array__ are accepted as tensor input.

    class Wrapper(object):

      def __array__(self, dtype=None):
        del dtype
        return np.array([b"foo", b"bar", b"baz"])

    t = tensor_util.make_tensor_proto(Wrapper(), shape=[1, 3])
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      string_val: "foo"
      string_val: "bar"
      string_val: "baz"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.object_, a.dtype)
    self.assertAllEqual(np.array([[b"foo", b"bar", b"baz"]]), a)

  def testArrayInterface(self):
    # Objects exposing __array_interface__ are accepted as tensor input.

    class Wrapper(object):

      def __init__(self):
        self.a = np.array([b"foo", b"bar", b"baz"])

      @property
      def __array_interface__(self):
        return self.a.__array_interface__

    t = tensor_util.make_tensor_proto(Wrapper(), shape=[1, 3])
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      string_val: "foo"
      string_val: "bar"
      string_val: "baz"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.object_, a.dtype)
    self.assertAllEqual(np.array([[b"foo", b"bar", b"baz"]]), a)

  def testStringTuple(self):
    # Tuples behave like lists for conversion purposes.
    t = tensor_util.make_tensor_proto((b"a", b"ab", b"abc", b"abcd"))
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape { dim { size: 4 } }
      string_val: "a"
      string_val: "ab"
      string_val: "abc"
      string_val: "abcd"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.object_, a.dtype)
    self.assertAllEqual(np.array((b"a", b"ab", b"abc", b"abcd")), a)

  def testStringNestedTuple(self):
    # Nested tuples produce a rank-2 string tensor.
    t = tensor_util.make_tensor_proto(((b"a", b"ab"), (b"abc", b"abcd")))
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape { dim { size: 2 } dim { size: 2 } }
      string_val: "a"
      string_val: "ab"
      string_val: "abc"
      string_val: "abcd"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.object_, a.dtype)
    self.assertAllEqual(np.array(((b"a", b"ab"), (b"abc", b"abcd"))), a)

  def testComplex64(self):
    # Complex values are stored as interleaved (real, imag) pairs.
    t = tensor_util.make_tensor_proto((1 + 2j), dtype=dtypes.complex64)
    self.assertProtoEquals("""
      dtype: DT_COMPLEX64
      tensor_shape {}
      scomplex_val: 1
      scomplex_val: 2
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.complex64, a.dtype)
    self.assertAllEqual(np.array(1 + 2j), a)

  def testComplex128(self):
    t = tensor_util.make_tensor_proto((1 + 2j), dtype=dtypes.complex128)
    self.assertProtoEquals("""
      dtype: DT_COMPLEX128
      tensor_shape {}
      dcomplex_val: 1
      dcomplex_val: 2
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.complex128, a.dtype)
    self.assertAllEqual(np.array(1 + 2j), a)

  def testComplexWithImplicitRepeat(self):
    # Implicit repeat also applies to complex scalars.
    for dtype, np_dtype in [(dtypes.complex64, np.complex64),
                            (dtypes.complex128, np.complex128)]:
      t = tensor_util.make_tensor_proto((1 + 1j), shape=[3, 4], dtype=dtype)
      a = tensor_util.MakeNdarray(t)
      self.assertAllClose(
          np.array(
              [[(1 + 1j), (1 + 1j), (1 + 1j), (1 + 1j)],
               [(1 + 1j), (1 + 1j), (1 + 1j), (1 + 1j)],
               [(1 + 1j), (1 + 1j), (1 + 1j), (1 + 1j)]],
              dtype=np_dtype),
          a)

  def testComplex64N(self):
    t = tensor_util.make_tensor_proto(
        [(1 + 2j), (3 + 4j), (5 + 6j)], shape=[1, 3], dtype=dtypes.complex64)
    self.assertProtoEquals("""
      dtype: DT_COMPLEX64
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      scomplex_val: 1
      scomplex_val: 2
      scomplex_val: 3
      scomplex_val: 4
      scomplex_val: 5
      scomplex_val: 6
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.complex64, a.dtype)
    self.assertAllEqual(np.array([[(1 + 2j), (3 + 4j), (5 + 6j)]]), a)

  def testComplex128N(self):
    t = tensor_util.make_tensor_proto(
        [(1 + 2j), (3 + 4j), (5 + 6j)], shape=[1, 3], dtype=dtypes.complex128)
    self.assertProtoEquals("""
      dtype: DT_COMPLEX128
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      dcomplex_val: 1
      dcomplex_val: 2
      dcomplex_val: 3
      dcomplex_val: 4
      dcomplex_val: 5
      dcomplex_val: 6
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.complex128, a.dtype)
    self.assertAllEqual(np.array([[(1 + 2j), (3 + 4j), (5 + 6j)]]), a)

  def testComplex64NpArray(self):
    t = tensor_util.make_tensor_proto(
        np.array([[(1 + 2j), (3 + 4j)], [(5 + 6j), (7 + 8j)]]),
        dtype=dtypes.complex64)
    # scomplex_val are real_0, imag_0, real_1, imag_1, ...
    self.assertProtoEquals("""
      dtype: DT_COMPLEX64
      tensor_shape { dim { size: 2 } dim { size: 2 } }
      scomplex_val: 1
      scomplex_val: 2
      scomplex_val: 3
      scomplex_val: 4
      scomplex_val: 5
      scomplex_val: 6
      scomplex_val: 7
      scomplex_val: 8
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.complex64, a.dtype)
    self.assertAllEqual(
        np.array([[(1 + 2j), (3 + 4j)], [(5 + 6j), (7 + 8j)]]), a)

  def testComplex128NpArray(self):
    t = tensor_util.make_tensor_proto(
        np.array([[(1 + 2j), (3 + 4j)], [(5 + 6j), (7 + 8j)]]),
        dtype=dtypes.complex128)
    # scomplex_val are real_0, imag_0, real_1, imag_1, ...
    self.assertProtoEquals("""
      dtype: DT_COMPLEX128
      tensor_shape { dim { size: 2 } dim { size: 2 } }
      dcomplex_val: 1
      dcomplex_val: 2
      dcomplex_val: 3
      dcomplex_val: 4
      dcomplex_val: 5
      dcomplex_val: 6
      dcomplex_val: 7
      dcomplex_val: 8
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.complex128, a.dtype)
    self.assertAllEqual(
        np.array([[(1 + 2j), (3 + 4j)], [(5 + 6j), (7 + 8j)]]), a)

  def testNestedNumpyArrayWithoutDType(self):
    # A mixed list of floats and a 0-d ndarray still infers float32.
    t = tensor_util.make_tensor_proto([10.0, 20.0, np.array(30.0)])
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testNestedNumpyArrayWithDType(self):
    t = tensor_util.make_tensor_proto([10.0, 20.0, np.array(30.0)],
                                      dtype=dtypes.float32)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testUnsupportedDTypes(self):
    # Invalid dtype argument or unconvertible quantized inputs raise
    # TypeError rather than producing a malformed proto.
    with self.assertRaises(TypeError):
      tensor_util.make_tensor_proto(np.array([1]), 0)
    with self.assertRaises(TypeError):
      tensor_util.make_tensor_proto(3, dtype=dtypes.qint8)
    with self.assertRaises(TypeError):
      tensor_util.make_tensor_proto([3], dtype=dtypes.qint8)

    # Validate the helpful error message when trying to convert an
    # unconvertible list as strings.
    with self.assertRaisesRegex(TypeError, "Failed to convert elements"):
      tensor_util.make_tensor_proto([tensor_shape.Dimension(1)])

  def testTensorShapeVerification(self):
    # verify_shape=True rejects a shape that disagrees with the data.
    array = np.array([[1], [2]])
    correct_shape = (2, 1)
    incorrect_shape = (1, 2)
    tensor_util.make_tensor_proto(array, shape=correct_shape, verify_shape=True)
    with self.assertRaises(TypeError):
      tensor_util.make_tensor_proto(
          array, shape=incorrect_shape, verify_shape=True)

  def testShapeTooLarge(self):
    # More values than the requested shape can hold is an error.
    with self.assertRaises(ValueError):
      tensor_util.make_tensor_proto(np.array([1, 2]), shape=[1])

  def testLowRankSupported(self):
    # 0-d numpy int arrays convert to DT_INT64 scalars.
    t = tensor_util.make_tensor_proto(np.array(7))
    self.assertProtoEquals("""
      dtype: DT_INT64
      tensor_shape {}
      int64_val: 7
      """, t)

  def testShapeEquals(self):
    # ShapeEquals accepts lists, tuples, and TensorShapeProto forms.
    t = tensor_util.make_tensor_proto([10, 20, 30, 40], shape=[2, 2])
    self.assertTrue(tensor_util.ShapeEquals(t, [2, 2]))
    self.assertTrue(tensor_util.ShapeEquals(t, (2, 2)))
    self.assertTrue(
        tensor_util.ShapeEquals(t, tensor_shape.as_shape([2, 2]).as_proto()))
    self.assertFalse(tensor_util.ShapeEquals(t, [5, 3]))
    self.assertFalse(tensor_util.ShapeEquals(t, [1, 4]))
    self.assertFalse(tensor_util.ShapeEquals(t, [4]))

799
@test_util.run_all_in_graph_and_eager_modes
class IsTensorTest(test.TestCase):
  """Checks tensor_util.is_tf_type on TF-native vs. plain-Python values."""

  def testConstantTensor(self):
    """A tf constant is a TF type; the numpy source array is not."""
    source_array = np.random.rand(3).astype(np.int32)
    tensor = constant_op.constant(source_array)
    self.assertTrue(tensor_util.is_tf_type(tensor))
    self.assertFalse(tensor_util.is_tf_type(source_array))

  def testRaggedTensor(self):
    """A RaggedTensor is a TF type; its evaluated value is not."""
    ragged = ragged_factory_ops.constant([[1, 2], [3]])
    evaluated = self.evaluate(ragged)
    self.assertTrue(tensor_util.is_tf_type(ragged))
    self.assertFalse(tensor_util.is_tf_type(evaluated))

  def testSparseTensor(self):
    """A SparseTensor is a TF type; its evaluated value is not."""
    sparse = sparse_tensor.SparseTensor([[1, 2]], [3], [10, 10])
    evaluated = self.evaluate(sparse)
    self.assertTrue(tensor_util.is_tf_type(sparse))
    self.assertFalse(tensor_util.is_tf_type(evaluated))

  def testIndexedSlices(self):
    """IndexedSlices is a TF type; IndexedSlicesValue (numpy-backed) is not."""
    slices = indexed_slices.IndexedSlices(
        constant_op.constant([1, 2, 3]), constant_op.constant([10, 20, 30]))
    slices_value = indexed_slices.IndexedSlicesValue(
        np.array([1, 2, 3]), np.array([10, 20, 30]), np.array([100]))
    self.assertTrue(tensor_util.is_tf_type(slices))
    self.assertFalse(tensor_util.is_tf_type(slices_value))

  def testVariable(self):
    """A tf.Variable counts as a TF type."""
    var = variables.Variable([1, 2, 3])
    self.assertTrue(tensor_util.is_tf_type(var))
833
class ConstantValueTest(test.TestCase):
  """Tests for tensor_util.constant_value.

  constant_value returns the statically known numpy value of a tensor, or
  None when the value cannot be determined without running the graph.
  All None-checks use assertIsNone (previously a mix of assertIs(None, ...)
  and assertIsNone) for consistent, clearer failure messages.
  """

  def testConstant(self):
    np_val = np.random.rand(3, 4, 7).astype(np.float32)
    tf_val = constant_op.constant(np_val)
    self.assertAllClose(np_val, tensor_util.constant_value(tf_val))

    # Zero-sized constants should round-trip as well.
    np_val = np.random.rand(3, 0, 7).astype(np.float32)
    tf_val = constant_op.constant(np_val)
    self.assertAllClose(np_val, tensor_util.constant_value(tf_val))

  def testUnknown(self):
    # A variable's value is unknown until the graph runs, so no constant
    # value can be inferred.
    with ops.Graph().as_default():
      tf_val = gen_state_ops.variable(
          shape=[3, 4, 7],
          dtype=dtypes.float32,
          name="tf_val",
          container="",
          shared_name="")
      self.assertIsNone(tensor_util.constant_value(tf_val))

  def testShape(self):
    np_val = np.array([1, 2, 3], dtype=np.int32)
    tf_val = array_ops.shape(constant_op.constant(0.0, shape=[1, 2, 3]))
    c_val = tensor_util.constant_value(tf_val)
    self.assertAllEqual(np_val, c_val)
    self.assertEqual(np.int32, c_val.dtype)

  def testFill(self):
    np_val = np.array([-1, -1, -1], dtype=np.float32)
    tf_val = array_ops.fill([3], constant_op.constant(-1.0))
    c_val = tensor_util.constant_value(tf_val)
    self.assertAllEqual(np_val, c_val)
    self.assertEqual(np.float32, c_val.dtype)

  def testSize(self):
    tf_val = array_ops.size(constant_op.constant(0.0, shape=[1, 2, 3]))
    c_val = tensor_util.constant_value(tf_val)
    self.assertEqual(6, c_val)

  def testSizeOfScalar(self):
    tf_val = array_ops.size(constant_op.constant(0.0))
    c_val = tensor_util.constant_value(tf_val)
    self.assertEqual(1, c_val)
    self.assertIn(type(c_val), [np.ndarray, np.int32])

  def testRank(self):
    tf_val = array_ops.rank(constant_op.constant(0.0, shape=[1, 2, 3]))
    c_val = tensor_util.constant_value(tf_val)

    self.assertIn(type(c_val), [np.ndarray, np.int32])
    self.assertEqual((), c_val.shape)
    self.assertEqual(3, c_val)

    # Repeat test using array_ops.rank_internal to avoid the optimization that
    # happens in the rank function.
    tf_val = array_ops.rank_internal(
        constant_op.constant(
            0.0, shape=[1, 2, 3]), optimize=False)
    c_val = tensor_util.constant_value(tf_val)

    self.assertIn(type(c_val), [np.ndarray, np.int32])
    self.assertEqual((), c_val.shape)
    self.assertEqual(3, c_val)
    # NOTE(review): comparing a one-element list against a scalar result
    # passes only via numpy's broadcasting equality; it is redundant with the
    # assertion above and could be dropped.
    self.assertEqual([3], c_val)

  def testCast(self):
    np_val = np.random.rand(3, 4, 7).astype(np.float32)
    tf_val = math_ops.cast(constant_op.constant(np_val), dtypes.float64)
    c_val = tensor_util.constant_value(tf_val)
    self.assertAllClose(np_val.astype(np.float64), c_val)

    # Zero-sized inputs cast cleanly too.
    np_val = np.random.rand(3, 0, 7).astype(np.float32)
    tf_val = math_ops.cast(constant_op.constant(np_val), dtypes.float64)
    c_val = tensor_util.constant_value(tf_val)
    self.assertAllClose(np_val.astype(np.float64), c_val)

  def testConcat(self):
    np_val = np.random.rand(3, 4, 7).astype(np.float32)
    tf_val = array_ops.concat(
        [np_val[0:1, :, :], np_val[1:2, :, :], np_val[2:3, :, :]], 0)
    c_val = tensor_util.constant_value(tf_val)
    self.assertAllClose(np_val, c_val)

    # This test needs a placeholder which means we need to construct a graph.
    with ops.Graph().as_default():
      # Unknown concat axis -> no constant value.
      tf_val = array_ops.concat(
          [np_val[0, :, :], np_val[1, :, :], np_val[2, :, :]],
          array_ops.placeholder(dtypes.int32))
      c_val = tensor_util.constant_value(tf_val)
      self.assertIsNone(c_val)

      # Unknown input among the operands -> no constant value.
      tf_val = array_ops.concat([
          np_val[0, :, :],
          array_ops.placeholder(dtypes.float32), np_val[2, :, :]
      ], 1)
      c_val = tensor_util.constant_value(tf_val)
      self.assertIsNone(c_val)

  def testPack_Axis0(self):
    inputs = [np.random.rand(4, 7) for _ in range(3)]
    np_val = np.array(inputs)
    tf_val = array_ops.stack(inputs)
    c_val = tensor_util.constant_value(tf_val)
    self.assertAllClose(np_val, c_val)

    # This test needs a placeholder which means we need to construct a graph.
    with ops.Graph().as_default():
      tf_val = array_ops.stack(
          [inputs[0],
           array_ops.placeholder(dtypes.float32), inputs[2]])
      c_val = tensor_util.constant_value(tf_val)
      self.assertIsNone(c_val)

  def testPack_Axis1(self):
    # Stacking along axis 1 is not constant-folded even for known inputs.
    # This test needs a placeholder which means we need to construct a graph.
    with ops.Graph().as_default():
      inputs = [np.random.rand(4, 7) for _ in range(3)]
      tf_val = array_ops.stack(inputs, axis=1)
      c_val = tensor_util.constant_value(tf_val)
      self.assertIsNone(c_val)

      tf_val = array_ops.stack(
          [inputs[0],
           array_ops.placeholder(dtypes.float32), inputs[2]], axis=1)
      c_val = tensor_util.constant_value(tf_val)
      self.assertIsNone(c_val)

  def testPack_Partial_Axis0(self):
    # With partial=True, known stacked elements are returned and unknown
    # ones come back as None.
    input_ = np.random.rand(4, 7)
    # This test needs a placeholder which means we need to construct a graph.
    with ops.Graph().as_default():
      tf_val = array_ops.stack([input_, array_ops.placeholder(dtypes.float32)])
      c_val = tensor_util.constant_value(tf_val, partial=True)
      self.assertAllClose(input_, c_val[0])
      self.assertIsNone(c_val[1])

  def testPack_Partial_Axis1(self):
    # Partial values are not supported for axis != 0.
    input_ = np.random.rand(4, 7)
    # This test needs a placeholder which means we need to construct a graph.
    with ops.Graph().as_default():
      tf_val = array_ops.stack(
          [input_, array_ops.placeholder(dtypes.float32)], axis=1)
      c_val = tensor_util.constant_value(tf_val, partial=True)
      self.assertIsNone(c_val)

  def testUnpack_Axis0(self):
    inputs = np.random.rand(3, 4, 7)
    tf_vals = array_ops.unstack(inputs)
    c_vals = [tensor_util.constant_value(x) for x in tf_vals]
    self.assertAllClose(inputs, c_vals)

  def testUnpack_Partial_Axis0(self):
    input_ = np.random.rand(4, 7)
    # This test needs a placeholder which means we need to construct a graph.
    with ops.Graph().as_default():
      packed = array_ops.stack([input_, array_ops.placeholder(dtypes.float32)])
      tf_vals = array_ops.unstack(packed)
      c_vals = [tensor_util.constant_value(x, partial=True) for x in tf_vals]
      self.assertAllClose(input_, c_vals[0])
      self.assertIsNone(c_vals[1])

  def testSplit_Axis0(self):
    inputs = np.random.rand(6, 5, 7)
    tf_vals = array_ops.split(inputs, 3)
    c_vals = [tensor_util.constant_value(x) for x in tf_vals]
    self.assertAllClose(np.split(inputs, 3), c_vals)

  def testSplit_Partial_Axis0(self):
    input_ = np.random.rand(4, 7)
    # This test needs a placeholder which means we need to construct a graph.
    with ops.Graph().as_default():
      placeholder = array_ops.placeholder(dtypes.float32, shape=(4, 7))
      # it'd be better to use concat here, but concat doesn't support partial
      packed = array_ops.stack([input_, placeholder])
      tf_vals = array_ops.split(packed, 2)
      c_vals = [tensor_util.constant_value(x, partial=True) for x in tf_vals]
      self.assertAllClose(input_, c_vals[0][0])
      self.assertIsNone(c_vals[1][0])

  def testEqual(self):
    # Scalar inputs.
    tf_val = math_ops.equal(constant_op.constant(1), constant_op.constant(1))
    self.assertEqual(tensor_util.constant_value(tf_val), True)

    tf_val = math_ops.equal(constant_op.constant(1), constant_op.constant(0))
    self.assertEqual(tensor_util.constant_value(tf_val), False)

    # Shaped inputs with broadcast semantics.
    tf_val = math_ops.equal(constant_op.constant([[0, 1]]),
                            constant_op.constant([[0], [1]]))
    c_val = tensor_util.constant_value(tf_val)
    self.assertAllEqual(c_val, [[True, False], [False, True]])

  def testNotEqual(self):
    # Scalar inputs.
    tf_val = math_ops.not_equal(constant_op.constant(1),
                                constant_op.constant(1))
    self.assertEqual(tensor_util.constant_value(tf_val), False)

    tf_val = math_ops.not_equal(constant_op.constant(1),
                                constant_op.constant(0))
    self.assertEqual(tensor_util.constant_value(tf_val), True)

    # Shaped inputs with broadcast semantics.
    tf_val = math_ops.not_equal(constant_op.constant([[0, 1]]),
                                constant_op.constant([[0], [1]]))
    c_val = tensor_util.constant_value(tf_val)
    self.assertAllEqual(c_val, [[False, True], [True, False]])

  def testStopGradient(self):
    # StopGradient is value-transparent for constant folding.
    input_ = np.random.rand(4, 7)
    tf_val = array_ops.stop_gradient(input_)
    c_val = tensor_util.constant_value(tf_val)
    self.assertAllEqual(input_, c_val)

  def testIdentity(self):
    # Identity is value-transparent for constant folding.
    input_ = np.random.rand(4, 7)
    tf_val = array_ops.identity(input_)
    c_val = tensor_util.constant_value(tf_val)
    self.assertAllEqual(input_, c_val)

  def testLiteral(self):
    # Non-tensor literals are returned as the same object.
    x = "hi"
    self.assertIs(x, tensor_util.constant_value(x))

  def testNumpyNdarray(self):
    # ndarrays are returned as the same object, not a copy.
    np_val = np.random.rand(3, 4, 7).astype(np.float32)
    self.assertIs(np_val, tensor_util.constant_value(np_val))

  def testVariable(self):
    var = variables.Variable(1.0, name="variable_node")
    self.assertIsNone(tensor_util.constant_value(var))

  def testVariableV1(self):
    var = variables.VariableV1(1.0, name="variable_node")
    self.assertIsNone(tensor_util.constant_value(var))
1071
1072
class ConstantValueAsShapeTest(test.TestCase):
  """Tests for tensor_util.constant_value_as_shape."""

  @test_util.run_in_graph_and_eager_modes
  def testConstant(self):
    dims = np.random.rand(3).astype(np.int32)
    shape_t = constant_op.constant(dims)
    self.assertEqual(
        tensor_shape.TensorShape(dims),
        tensor_util.constant_value_as_shape(shape_t))

    empty_t = constant_op.constant([], dtype=dtypes.int32)
    self.assertEqual(
        tensor_shape.TensorShape([]),
        tensor_util.constant_value_as_shape(empty_t))

  @test_util.run_in_graph_and_eager_modes
  def testCast(self):
    # Casting the shape tensor must not lose the static value.
    casted = math_ops.cast(
        array_ops.shape(constant_op.constant(0.0, shape=[1, 2, 3])),
        dtypes.int64)
    result = tensor_util.constant_value_as_shape(casted)
    self.assertEqual(tensor_shape.TensorShape([1, 2, 3]), result)

  @test_util.run_in_graph_and_eager_modes
  def testCastWithUnknown(self):
    # -1 entries map to unknown (None) dimensions.
    casted = math_ops.cast(constant_op.constant([-1, 1, -1]), dtypes.int64)
    result = tensor_util.constant_value_as_shape(casted)
    self.assertEqual([None, 1, None], result.as_list())

  @test_util.run_in_graph_and_eager_modes
  def testShape(self):
    shape_t = array_ops.shape(constant_op.constant(0.0, shape=[1, 2, 3]))
    result = tensor_util.constant_value_as_shape(shape_t)
    self.assertEqual(tensor_shape.TensorShape([1, 2, 3]), result)

  @test_util.run_in_graph_and_eager_modes
  def testMinusOneBecomesNone(self):
    shape_t = constant_op.constant([-1, 1, -1], shape=[3])
    result = tensor_util.constant_value_as_shape(shape_t)
    self.assertEqual([None, 1, None], result.as_list())

  def testPack(self):
    # Placeholders are graph-only, so build inside an explicit graph.
    with ops.Graph().as_default():
      packed = array_ops.stack(
          [constant_op.constant(16), 37,
           array_ops.placeholder(dtypes.int32)])
      result = tensor_util.constant_value_as_shape(packed)
      self.assertEqual([16, 37, None], result.as_list())

  def testConcat(self):
    # Placeholders are graph-only, so build inside an explicit graph.
    with ops.Graph().as_default():
      joined = array_ops.concat(
          [[16, 37], array_ops.placeholder(dtypes.int32, shape=(2,))], 0)
      result = tensor_util.constant_value_as_shape(joined)
      self.assertEqual([16, 37, None, None], result.as_list())

      joined = array_ops.concat(
          [[16, 37],
           array_ops.placeholder(dtypes.int32, shape=(1,)), [48]], 0)
      result = tensor_util.constant_value_as_shape(joined)
      self.assertEqual([16, 37, None, 48], result.as_list())

  def testSlice(self):
    # Slicing a fully-unknown shape vector yields unknown dimensions.
    # This test needs a placeholder which means we need to construct a graph.
    with ops.Graph().as_default():
      sliced = array_ops.placeholder(dtypes.int32, shape=(4,))[0:2]
      result = tensor_util.constant_value_as_shape(sliced)
      self.assertEqual([None, None], result.as_list())

    # begin:end
    sliced = constant_op.constant([10, 20, 30])[1:3]
    result = tensor_util.constant_value_as_shape(sliced)
    self.assertEqual([20, 30], result.as_list())

    # begin:end:stride
    sliced = array_ops.strided_slice(
        constant_op.constant([10, 20, 30]), [1], [3], strides=[2])
    result = tensor_util.constant_value_as_shape(sliced)
    self.assertEqual([20], result.as_list())

    # The remaining cases slice a partially-known shape vector:
    # [1, 2, 16, 37, None, 48]
    # This test needs a placeholder which means we need to construct a graph.
    with ops.Graph().as_default():
      base = array_ops.concat(
          [[1, 2, 16, 37],
           array_ops.placeholder(dtypes.int32, shape=(1,)), [48]], 0)

      slice_cases = [
          # begin: no end
          (base[2:], [16, 37, None, 48]),
          # begin::negative slice
          (base[2::-1], [16, 2, 1]),
          # :end:negative slice
          (base[:1:-2], [48, 37]),
          # begin:end:negative slice
          (base[3:1:-1], [37, 16]),
          # begin:negative end:slice
          (base[1:-3:1], [2, 16]),
          # negative begin::slice
          (base[-3::1], [37, None, 48]),
          # negative begin::negative slice
          (base[-3::-1], [37, 16, 2, 1]),
          # negative begin:negative end:negative slice
          (base[-3:-5:-1], [37, 16]),
      ]
      for sliced, expected in slice_cases:
        result = tensor_util.constant_value_as_shape(sliced)
        self.assertEqual(expected, result.as_list())

      # Ellipsis and other extra slicing arguments are not shape-inferred.
      sliced = constant_op.constant([10, 20, 30])[...]
      result = tensor_util.constant_value_as_shape(sliced)
      self.assertEqual([None, None, None], result.as_list())

      # Tensor-valued slice bounds are not shape-inferred.
      sliced = constant_op.constant(
          [10, 20, 30])[array_ops.placeholder(dtypes.int32, shape=()):]
      result = tensor_util.constant_value_as_shape(sliced)
      self.assertEqual(tensor_shape.unknown_shape(), result)

      # Inputs of rank higher than one are rejected outright.
      with self.assertRaises(ValueError):
        tensor_util.constant_value_as_shape(
            constant_op.constant([[10], [20], [30]])[:, 0:])
1217
1218
class MaybeSetStaticShapeTest(test.TestCase):
  """Tests that disabling static-shape propagation leaves the op set intact."""

  @contextlib.contextmanager
  def disableSetStaticShape(self):
    """Temporarily turns off tensor_util's maybe-set-static-shape behavior."""
    previous = tensor_util._ENABLE_MAYBE_SET_STATIC_SHAPE
    tensor_util._ENABLE_MAYBE_SET_STATIC_SHAPE = False
    try:
      yield
    finally:
      tensor_util._ENABLE_MAYBE_SET_STATIC_SHAPE = previous

  def testMaybeSetStaticShape(self):
    target_shape = constant_op.constant([2, 5], dtype=dtypes.int32)

    def build_reshape():
      zeros = array_ops.zeros([10])
      return array_ops.reshape(zeros, target_shape)

    # This test needs a placeholder which means we need to construct a graph.
    with ops.Graph().as_default():
      with self.disableSetStaticShape():
        unpropagated = func_graph.func_graph_from_py_func(
            "without_shape_propagation", build_reshape, [], {})
      propagated = func_graph.func_graph_from_py_func(
          "with_shape_propagation", build_reshape, [], {})
      # Both graphs should contain the same multiset of op types.
      self.assertCountEqual(
          [op.type for op in unpropagated.get_operations()],
          [op.type for op in propagated.get_operations()])

  def testMaybeSetStaticShapeScalarShape(self):

    def build_reshape():
      placeholder = array_ops.placeholder(dtypes.float32)
      return array_ops.reshape(placeholder, [-1])

    with self.disableSetStaticShape():
      unpropagated = func_graph.func_graph_from_py_func(
          "without_shape_propagation", build_reshape, [], {})
    propagated = func_graph.func_graph_from_py_func(
        "with_shape_propagation", build_reshape, [], {})
    # Both graphs should contain the same multiset of op types.
    self.assertCountEqual(
        [op.type for op in unpropagated.get_operations()],
        [op.type for op in propagated.get_operations()])
1262
1263
class ShapeTensorTest(test_util.TensorFlowTestCase):
  """Tests for tensor_util.shape_tensor."""

  @test_util.run_in_graph_and_eager_modes
  def testConversion(self):
    """A fully known TensorShape should convert to a (1, 2) tensor."""
    known_shape = tensor_shape.TensorShape([1, tensor_shape.Dimension(2)])
    as_tensor = tensor_util.shape_tensor(known_shape)
    self.assertAllEqual((1, 2), as_tensor)
1272
1273
if __name__ == "__main__":
  # Run every test case defined in this module.
  test.main()
1276