1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Functional tests for tensor_util.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import contextlib 22import sys 23 24from absl.testing import parameterized 25import numpy as np 26 27from tensorflow.python.framework import constant_op 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import func_graph 30from tensorflow.python.framework import indexed_slices 31from tensorflow.python.framework import ops 32from tensorflow.python.framework import sparse_tensor 33from tensorflow.python.framework import tensor_shape 34from tensorflow.python.framework import tensor_util 35from tensorflow.python.framework import test_util 36from tensorflow.python.ops import array_ops 37from tensorflow.python.ops import gen_state_ops 38from tensorflow.python.ops import math_ops 39from tensorflow.python.ops import variables 40from tensorflow.python.ops.ragged import ragged_factory_ops 41from tensorflow.python.platform import test 42 43 44@test_util.run_all_in_graph_and_eager_modes 45class TensorUtilTest(test.TestCase, parameterized.TestCase): 46 47 def testFloat(self): 48 value = 10.0 49 t = tensor_util.make_tensor_proto(value) 50 self.assertProtoEquals(""" 51 dtype: DT_FLOAT 52 tensor_shape {} 53 float_val: %.1f 54 """ % value, t) 55 a = tensor_util.MakeNdarray(t) 56 self.assertEqual(np.float32, a.dtype) 57 self.assertAllClose(np.array(value, dtype=np.float32), a) 58 59 def testFloatN(self): 60 t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0]) 61 if sys.byteorder == "big": 62 self.assertProtoEquals(r""" 63 dtype: DT_FLOAT 64 tensor_shape { dim { size: 3 } } 65 tensor_content: "A \000\000A\240\000\000A\360\000\000" 66 """, t) 67 else: 68 self.assertProtoEquals(r""" 69 dtype: DT_FLOAT 70 tensor_shape { dim { size: 3 } } 71 tensor_content: "\000\000 A\000\000\240A\000\000\360A" 72 """, t) 73 a = tensor_util.MakeNdarray(t) 74 self.assertEqual(np.float32, a.dtype) 75 self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a) 76 77 def testFloatTyped(self): 78 t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], dtype=dtypes.float32) 79 if sys.byteorder == "big": 80 self.assertProtoEquals(r""" 81 dtype: DT_FLOAT 82 tensor_shape { dim { size: 3 } } 83 tensor_content: "A \000\000A\240\000\000A\360\000\000" 84 """, t) 85 else: 86 self.assertProtoEquals(r""" 87 dtype: DT_FLOAT 88 tensor_shape { dim { size: 3 } } 89 tensor_content: "\000\000 A\000\000\240A\000\000\360A" 90 """, t) 91 a = tensor_util.MakeNdarray(t) 92 self.assertEqual(np.float32, a.dtype) 93 self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a) 94 95 def testFloatTypeCoerce(self): 96 t = tensor_util.make_tensor_proto([10, 20, 30], dtype=dtypes.float32) 97 if sys.byteorder == "big": 98 self.assertProtoEquals(r""" 99 dtype: DT_FLOAT 100 tensor_shape { dim { size: 3 } } 101 tensor_content: "A \000\000A\240\000\000A\360\000\000" 102 """, t) 103 else: 104 self.assertProtoEquals(r""" 105 dtype: DT_FLOAT 106 tensor_shape { dim { size: 3 } } 107 tensor_content: "\000\000 A\000\000\240A\000\000\360A" 108 """, t) 109 a = tensor_util.MakeNdarray(t) 110 self.assertEqual(np.float32, a.dtype) 111 self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a) 112 113 def testFloatTypeCoerceNdarray(self): 114 arr = np.asarray([10, 20, 30], dtype="int") 115 t = tensor_util.make_tensor_proto(arr, dtype=dtypes.float32) 116 if sys.byteorder == "big": 117 self.assertProtoEquals(r""" 118 dtype: DT_FLOAT 119 tensor_shape { dim { size: 3 } } 120 tensor_content: "A \000\000A\240\000\000A\360\000\000" 121 """, t) 122 else: 123 self.assertProtoEquals(r""" 124 dtype: DT_FLOAT 125 tensor_shape { dim { size: 3 } } 126 tensor_content: "\000\000 A\000\000\240A\000\000\360A" 127 """, t) 128 a = tensor_util.MakeNdarray(t) 129 self.assertEqual(np.float32, a.dtype) 130 self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a) 131 132 def testFloatSizes(self): 133 t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], shape=[1, 3]) 134 if sys.byteorder == "big": 135 self.assertProtoEquals(r""" 136 dtype: DT_FLOAT 137 tensor_shape { dim { size: 1 } dim { size: 3 } } 138 tensor_content: "A \000\000A\240\000\000A\360\000\000" 139 """, t) 140 else: 141 self.assertProtoEquals(r""" 142 dtype: DT_FLOAT 143 tensor_shape { dim { size: 1 } dim { size: 3 } } 144 tensor_content: "\000\000 A\000\000\240A\000\000\360A" 145 """, t) 146 a = tensor_util.MakeNdarray(t) 147 self.assertEqual(np.float32, a.dtype) 148 self.assertAllClose(np.array([[10.0, 20.0, 30.0]], dtype=np.float32), a) 149 150 def testFloatSizes2(self): 151 t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], shape=[3, 1]) 152 if sys.byteorder == "big": 153 self.assertProtoEquals(r""" 154 dtype: DT_FLOAT 155 tensor_shape { dim { size: 3 } dim { size: 1 } } 156 tensor_content: "A \000\000A\240\000\000A\360\000\000" 157 """, t) 158 else: 159 self.assertProtoEquals(r""" 160 dtype: DT_FLOAT 161 tensor_shape { dim { size: 3 } dim { size: 1 } } 162 tensor_content: "\000\000 A\000\000\240A\000\000\360A" 163 """, t) 164 a = tensor_util.MakeNdarray(t) 165 self.assertEqual(np.float32, a.dtype) 166 self.assertAllClose(np.array([[10.0], [20.0], [30.0]], dtype=np.float32), a) 167 168 def testFloatSizesLessValues(self): 169 t = tensor_util.make_tensor_proto(10.0, shape=[1, 3]) 170 self.assertProtoEquals(""" 171 dtype: DT_FLOAT 172 tensor_shape { dim { size: 1 } dim { size: 3 } } 173 float_val: 10.0 174 """, t) 175 # No conversion to Ndarray for this one: not enough values. 176 177 def testFloatNpArrayFloat64(self): 178 t = tensor_util.make_tensor_proto( 179 np.array([[10.0, 20.0, 30.0]], dtype=np.float64)) 180 if sys.byteorder == "big": 181 self.assertProtoEquals(r""" 182 dtype: DT_DOUBLE 183 tensor_shape { dim { size: 1 } dim { size: 3 } } 184 tensor_content: "@$\000\000\000\000\000\000@4\000\000\000\000\000\000@>\000\000\000\000\000\000" 185 """, t) 186 else: 187 self.assertProtoEquals(r""" 188 dtype: DT_DOUBLE 189 tensor_shape { dim { size: 1 } dim { size: 3 } } 190 tensor_content: "\000\000\000\000\000\000$@\000\000\000\000\000\0004@\000\000\000\000\000\000>@" 191 """, t) 192 a = tensor_util.MakeNdarray(t) 193 self.assertEqual(np.float64, a.dtype) 194 self.assertAllClose( 195 np.array([[10.0, 20.0, 30.0]], dtype=np.float64), 196 tensor_util.MakeNdarray(t)) 197 198 def testFloatTypesWithImplicitRepeat(self): 199 for dtype, nptype in [(dtypes.float32, np.float32), 200 (dtypes.float64, np.float64)]: 201 t = tensor_util.make_tensor_proto([10.0], shape=[3, 4], dtype=dtype) 202 a = tensor_util.MakeNdarray(t) 203 self.assertAllClose( 204 np.array( 205 [[10.0, 10.0, 10.0, 10.0], 206 [10.0, 10.0, 10.0, 10.0], 207 [10.0, 10.0, 10.0, 10.0]], 208 dtype=nptype), 209 a) 210 211 def testFloatMutateArray(self): 212 t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], dtype=dtypes.float32) 213 a = tensor_util.MakeNdarray(t) 214 a[0] = 5.0 215 self.assertEqual(np.float32, a.dtype) 216 self.assertAllClose(np.array([5.0, 20.0, 30.0], dtype=np.float32), a) 217 if sys.byteorder == "big": 218 self.assertProtoEquals(r""" 219 dtype: DT_FLOAT 220 tensor_shape { dim { size: 3 } } 221 tensor_content: "A \000\000A\240\000\000A\360\000\000" 222 """, t) 223 else: 224 self.assertProtoEquals(r""" 225 dtype: DT_FLOAT 226 tensor_shape { dim { size: 3 } } 227 tensor_content: "\000\000 A\000\000\240A\000\000\360A" 228 """, t) 229 230 def testHalf(self): 231 t = tensor_util.make_tensor_proto(np.array([10.0, 20.0], dtype=np.float16)) 232 self.assertProtoEquals( 233 """ 234 dtype: DT_HALF 235 tensor_shape { dim { size: 2 } } 236 tensor_content: "\000I\000M" 237 """, t) 238 239 a = tensor_util.MakeNdarray(t) 240 self.assertEqual(np.float16, a.dtype) 241 self.assertAllClose(np.array([10.0, 20.0], dtype=np.float16), a) 242 243 def testBfloat16(self): 244 test_type = dtypes.bfloat16.as_numpy_dtype 245 t = tensor_util.make_tensor_proto(np.array([10.0, 20.0], dtype=test_type)) 246 # 10.0: 16672 = 010000010(130) 0100000: (1+0/2+1/4) * 2^(130-127) 247 # 20.0: 16800 = 010000011(131) 0100000: (1+0/2+1/4) * 2^(131-127) 248 self.assertProtoEquals(""" 249 dtype: DT_BFLOAT16 250 tensor_shape { 251 dim { 252 size: 2 253 } 254 } 255 half_val: 16672 256 half_val: 16800 257 """, t) 258 259 a = tensor_util.MakeNdarray(t) 260 self.assertEqual(test_type, a.dtype) 261 self.assertAllClose(np.array([10.0, 20.0], dtype=test_type), a) 262 263 def testInt(self): 264 t = tensor_util.make_tensor_proto(10) 265 self.assertProtoEquals(""" 266 dtype: DT_INT32 267 tensor_shape {} 268 int_val: 10 269 """, t) 270 a = tensor_util.MakeNdarray(t) 271 self.assertEqual(np.int32, a.dtype) 272 self.assertAllClose(np.array(10, dtype=np.int32), a) 273 274 def testLargeInt(self): 275 value = np.iinfo(np.int64).max 276 t = tensor_util.make_tensor_proto(value) 277 self.assertProtoEquals(""" 278 dtype: DT_INT64 279 tensor_shape {} 280 int64_val: %d 281 """ % value, t) 282 a = tensor_util.MakeNdarray(t) 283 self.assertEqual(np.int64, a.dtype) 284 self.assertAllClose(np.array(value, dtype=np.int64), a) 285 286 def testLargeNegativeInt(self): 287 # We don't use the min np.int64 value here 288 # because it breaks np.abs(). 289 # 290 # np.iinfo(np.int64).min = -9223372036854775808 291 # np.iinfo(np.int64).max = 9223372036854775807 292 # np.abs(-9223372036854775808) = -9223372036854775808 293 value = np.iinfo(np.int64).min + 1 294 t = tensor_util.make_tensor_proto(value) 295 self.assertProtoEquals(""" 296 dtype: DT_INT64 297 tensor_shape {} 298 int64_val: %d 299 """ % value, t) 300 a = tensor_util.MakeNdarray(t) 301 self.assertEqual(np.int64, a.dtype) 302 self.assertAllClose(np.array(value, dtype=np.int64), a) 303 304 def testIntNDefaultType(self): 305 t = tensor_util.make_tensor_proto([10, 20, 30, 40], shape=[2, 2]) 306 if sys.byteorder == "big": 307 self.assertProtoEquals(r""" 308 dtype: DT_INT32 309 tensor_shape { dim { size: 2 } dim { size: 2 } } 310 tensor_content: "\000\000\000\n\000\000\000\024\000\000\000\036\000\000\000(" 311 """, t) 312 else: 313 self.assertProtoEquals(r""" 314 dtype: DT_INT32 315 tensor_shape { dim { size: 2 } dim { size: 2 } } 316 tensor_content: "\n\000\000\000\024\000\000\000\036\000\000\000(\000\000\000" 317 """, t) 318 a = tensor_util.MakeNdarray(t) 319 self.assertEqual(np.int32, a.dtype) 320 self.assertAllClose(np.array([[10, 20], [30, 40]], dtype=np.int32), a) 321 322 @parameterized.named_parameters( 323 ("_int8", dtypes.int8, np.int8), ("_int16", dtypes.int16, np.int16), 324 ("_int32", dtypes.int32, np.int32), ("_int64", dtypes.int64, np.int64), 325 ("_uint8", dtypes.uint8, np.uint8), ("_uint16", dtypes.uint16, np.uint16), 326 ("_uint32", dtypes.uint32, np.uint32), 327 ("_uint64", dtypes.uint64, np.uint64)) 328 def testIntTypes(self, dtype, nptype): 329 # Test with array. 330 t = tensor_util.make_tensor_proto([10, 20, 30], dtype=dtype) 331 self.assertEqual(dtype, t.dtype) 332 self.assertProtoEquals("dim { size: 3 }", t.tensor_shape) 333 a = tensor_util.MakeNdarray(t) 334 self.assertEqual(nptype, a.dtype) 335 self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a) 336 # Test with ndarray. 337 t = tensor_util.make_tensor_proto(np.array([10, 20, 30], dtype=nptype)) 338 self.assertEqual(dtype, t.dtype) 339 self.assertProtoEquals("dim { size: 3 }", t.tensor_shape) 340 a = tensor_util.MakeNdarray(t) 341 self.assertEqual(nptype, a.dtype) 342 self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a) 343 344 @parameterized.named_parameters( 345 ("_int8", dtypes.int8, np.int8), ("_int16", dtypes.int16, np.int16), 346 ("_int32", dtypes.int32, np.int32), ("_int64", dtypes.int64, np.int64), 347 ("_uint8", dtypes.uint8, np.uint8), ("_uint16", dtypes.uint16, np.uint16), 348 ("_uint32", dtypes.uint32, np.uint32), 349 ("_uint64", dtypes.uint64, np.uint64)) 350 def testIntTypesWithImplicitRepeat(self, dtype, nptype): 351 self.assertAllEqual( 352 np.array([[10, 11, 12, 12], [12, 12, 12, 12], [12, 12, 12, 12]], 353 dtype=nptype), 354 tensor_util.MakeNdarray( 355 tensor_util.make_tensor_proto([10, 11, 12], 356 shape=[3, 4], 357 dtype=dtype))) 358 359 def testIntMixedWithDimension(self): 360 # Github issue: 11974 361 dtype = dtypes.int32 362 nptype = np.int32 363 t = tensor_util.make_tensor_proto( 364 [10, tensor_shape.Dimension(20), 30], dtype=dtype) 365 self.assertEqual(dtype, t.dtype) 366 a = tensor_util.MakeNdarray(t) 367 self.assertEqual(nptype, a.dtype) 368 self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a) 369 370 @parameterized.named_parameters( 371 ("_int64", dtypes.int64, np.int64, "DT_INT64", "int64_val"), 372 ("_uint64", dtypes.uint64, np.uint64, "DT_UINT64", "uint64_val")) 373 def testLong(self, dtype, nptype, proto_dtype, proto_value_name): 374 t = tensor_util.make_tensor_proto(10, dtype=dtype) 375 self.assertProtoEquals( 376 """ 377 dtype: %s 378 tensor_shape {} 379 %s: 10 380 """ % (proto_dtype, proto_value_name), t) 381 a = tensor_util.MakeNdarray(t) 382 self.assertEqual(nptype, a.dtype) 383 self.assertAllClose(np.array(10, dtype=nptype), a) 384 385 @parameterized.named_parameters( 386 ("_int64", dtypes.int64, np.int64, "DT_INT64"), 387 ("_uint64", dtypes.uint64, np.uint64, "DT_UINT64")) 388 def testLongN(self, dtype, nptype, proto_dtype): 389 t = tensor_util.make_tensor_proto([10, 20, 30], shape=[1, 3], dtype=dtype) 390 if sys.byteorder == "big": 391 # pylint: disable=line-too-long 392 self.assertProtoEquals( 393 r""" 394 dtype: %s 395 tensor_shape { dim { size: 1 } dim { size: 3 } } 396 tensor_content: "\000\000\000\000\000\000\000\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036" 397 """ % proto_dtype, t) 398 # pylint: enable=line-too-long 399 else: 400 # pylint: disable=line-too-long 401 self.assertProtoEquals( 402 r""" 403 dtype: %s 404 tensor_shape { dim { size: 1 } dim { size: 3 } } 405 tensor_content: "\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000" 406 """ % proto_dtype, t) 407 # pylint: enable=line-too-long 408 a = tensor_util.MakeNdarray(t) 409 self.assertEqual(nptype, a.dtype) 410 self.assertAllClose(np.array([[10, 20, 30]], dtype=nptype), a) 411 412 @parameterized.named_parameters(("_int64", np.int64, "DT_INT64"), 413 ("_uint64", np.uint64, "DT_UINT64")) 414 def testLongNpArray(self, nptype, proto_dtype): 415 t = tensor_util.make_tensor_proto(np.array([10, 20, 30], dtype=nptype)) 416 if sys.byteorder == "big": 417 # pylint: disable=line-too-long 418 self.assertProtoEquals( 419 r""" 420 dtype: %s 421 tensor_shape { dim { size: 3 } } 422 tensor_content: "\000\000\000\000\000\000\000\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036" 423 """ % proto_dtype, t) 424 # pylint: enable=line-too-long 425 else: 426 # pylint: disable=line-too-long 427 self.assertProtoEquals( 428 r""" 429 dtype: %s 430 tensor_shape { dim { size: 3 } } 431 tensor_content: "\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000" 432 """ % proto_dtype, t) 433 # pylint: enable=line-too-long 434 a = tensor_util.MakeNdarray(t) 435 self.assertEqual(nptype, a.dtype) 436 self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a) 437 438 def testQuantizedTypes(self): 439 # Test with array. 440 data = [(21,), (22,), (23,)] 441 442 t = tensor_util.make_tensor_proto(data, dtype=dtypes.qint32) 443 if sys.byteorder == "big": 444 self.assertProtoEquals(r""" 445 dtype: DT_QINT32 446 tensor_shape { dim { size: 3 } } 447 tensor_content: "\000\000\000\025\000\000\000\026\000\000\000\027" 448 """, t) 449 else: 450 self.assertProtoEquals(r""" 451 dtype: DT_QINT32 452 tensor_shape { dim { size: 3 } } 453 tensor_content: "\025\000\000\000\026\000\000\000\027\000\000\000" 454 """, t) 455 a = tensor_util.MakeNdarray(t) 456 self.assertEqual(dtypes.qint32.as_numpy_dtype, a.dtype) 457 self.assertAllEqual(np.array(data, dtype=a.dtype), a) 458 459 t = tensor_util.make_tensor_proto(data, dtype=dtypes.quint8) 460 self.assertProtoEquals(r""" 461 dtype: DT_QUINT8 462 tensor_shape { dim { size: 3 } } 463 tensor_content: "\025\026\027" 464 """, t) 465 a = tensor_util.MakeNdarray(t) 466 self.assertEqual(dtypes.quint8.as_numpy_dtype, a.dtype) 467 self.assertAllEqual(np.array(data, dtype=a.dtype), a) 468 469 t = tensor_util.make_tensor_proto(data, dtype=dtypes.qint8) 470 self.assertProtoEquals(r""" 471 dtype: DT_QINT8 472 tensor_shape { dim { size: 3 } } 473 tensor_content: "\025\026\027" 474 """, t) 475 a = tensor_util.MakeNdarray(t) 476 self.assertEqual(dtypes.qint8.as_numpy_dtype, a.dtype) 477 self.assertAllEqual(np.array(data, dtype=a.dtype), a) 478 479 t = tensor_util.make_tensor_proto(data, dtype=dtypes.quint16) 480 if sys.byteorder == "big": 481 self.assertProtoEquals(r""" 482 dtype: DT_QUINT16 483 tensor_shape { dim { size: 3 } } 484 tensor_content: "\000\025\000\026\000\027" 485 """, t) 486 else: 487 self.assertProtoEquals(r""" 488 dtype: DT_QUINT16 489 tensor_shape { dim { size: 3 } } 490 tensor_content: "\025\000\026\000\027\000" 491 """, t) 492 a = tensor_util.MakeNdarray(t) 493 self.assertEqual(dtypes.quint16.as_numpy_dtype, a.dtype) 494 self.assertAllEqual(np.array(data, dtype=a.dtype), a) 495 496 t = tensor_util.make_tensor_proto(data, dtype=dtypes.qint16) 497 if sys.byteorder == "big": 498 self.assertProtoEquals(r""" 499 dtype: DT_QINT16 500 tensor_shape { dim { size: 3 } } 501 tensor_content: "\000\025\000\026\000\027" 502 """, t) 503 else: 504 self.assertProtoEquals(r""" 505 dtype: DT_QINT16 506 tensor_shape { dim { size: 3 } } 507 tensor_content: "\025\000\026\000\027\000" 508 """, t) 509 a = tensor_util.MakeNdarray(t) 510 self.assertEqual(dtypes.qint16.as_numpy_dtype, a.dtype) 511 self.assertAllEqual(np.array(data, dtype=a.dtype), a) 512 513 def testString(self): 514 t = tensor_util.make_tensor_proto("foo") 515 self.assertProtoEquals(""" 516 dtype: DT_STRING 517 tensor_shape {} 518 string_val: "foo" 519 """, t) 520 a = tensor_util.MakeNdarray(t) 521 self.assertEqual(np.object, a.dtype) 522 self.assertEqual([b"foo"], a) 523 524 def testStringWithImplicitRepeat(self): 525 t = tensor_util.make_tensor_proto(["f", "g"], shape=[3, 4]) 526 a = tensor_util.MakeNdarray(t) 527 self.assertAllEqual( 528 np.array([[b"f", b"g", b"g", b"g"], [b"g", b"g", b"g", b"g"], 529 [b"g", b"g", b"g", b"g"]], 530 dtype=np.object), a) 531 532 def testStringN(self): 533 t = tensor_util.make_tensor_proto([b"foo", b"bar", b"baz"], shape=[1, 3]) 534 self.assertProtoEquals(""" 535 dtype: DT_STRING 536 tensor_shape { dim { size: 1 } dim { size: 3 } } 537 string_val: "foo" 538 string_val: "bar" 539 string_val: "baz" 540 """, t) 541 a = tensor_util.MakeNdarray(t) 542 self.assertEqual(np.object, a.dtype) 543 self.assertAllEqual(np.array([[b"foo", b"bar", b"baz"]]), a) 544 545 def testStringNpArray(self): 546 t = tensor_util.make_tensor_proto( 547 np.array([[b"a", b"ab"], [b"abc", b"abcd"]])) 548 self.assertProtoEquals(""" 549 dtype: DT_STRING 550 tensor_shape { dim { size: 2 } dim { size: 2 } } 551 string_val: "a" 552 string_val: "ab" 553 string_val: "abc" 554 string_val: "abcd" 555 """, t) 556 a = tensor_util.MakeNdarray(t) 557 self.assertEqual(np.object, a.dtype) 558 self.assertAllEqual(np.array([[b"a", b"ab"], [b"abc", b"abcd"]]), a) 559 560 def testArrayMethod(self): 561 562 class Wrapper(object): 563 564 def __array__(self): 565 return np.array([b"foo", b"bar", b"baz"]) 566 567 t = tensor_util.make_tensor_proto(Wrapper(), shape=[1, 3]) 568 self.assertProtoEquals(""" 569 dtype: DT_STRING 570 tensor_shape { dim { size: 1 } dim { size: 3 } } 571 string_val: "foo" 572 string_val: "bar" 573 string_val: "baz" 574 """, t) 575 a = tensor_util.MakeNdarray(t) 576 self.assertEqual(np.object, a.dtype) 577 self.assertAllEqual(np.array([[b"foo", b"bar", b"baz"]]), a) 578 579 def testArrayInterface(self): 580 581 class Wrapper(object): 582 583 @property 584 def __array_interface__(self): 585 return np.array([b"foo", b"bar", b"baz"]).__array_interface__ 586 587 t = tensor_util.make_tensor_proto(Wrapper(), shape=[1, 3]) 588 self.assertProtoEquals(""" 589 dtype: DT_STRING 590 tensor_shape { dim { size: 1 } dim { size: 3 } } 591 string_val: "foo" 592 string_val: "bar" 593 string_val: "baz" 594 """, t) 595 a = tensor_util.MakeNdarray(t) 596 self.assertEqual(np.object, a.dtype) 597 self.assertAllEqual(np.array([[b"foo", b"bar", b"baz"]]), a) 598 599 def testStringTuple(self): 600 t = tensor_util.make_tensor_proto((b"a", b"ab", b"abc", b"abcd")) 601 self.assertProtoEquals(""" 602 dtype: DT_STRING 603 tensor_shape { dim { size: 4 } } 604 string_val: "a" 605 string_val: "ab" 606 string_val: "abc" 607 string_val: "abcd" 608 """, t) 609 a = tensor_util.MakeNdarray(t) 610 self.assertEqual(np.object, a.dtype) 611 self.assertAllEqual(np.array((b"a", b"ab", b"abc", b"abcd")), a) 612 613 def testStringNestedTuple(self): 614 t = tensor_util.make_tensor_proto(((b"a", b"ab"), (b"abc", b"abcd"))) 615 self.assertProtoEquals(""" 616 dtype: DT_STRING 617 tensor_shape { dim { size: 2 } dim { size: 2 } } 618 string_val: "a" 619 string_val: "ab" 620 string_val: "abc" 621 string_val: "abcd" 622 """, t) 623 a = tensor_util.MakeNdarray(t) 624 self.assertEqual(np.object, a.dtype) 625 self.assertAllEqual(np.array(((b"a", b"ab"), (b"abc", b"abcd"))), a) 626 627 def testComplex64(self): 628 t = tensor_util.make_tensor_proto((1 + 2j), dtype=dtypes.complex64) 629 self.assertProtoEquals(""" 630 dtype: DT_COMPLEX64 631 tensor_shape {} 632 scomplex_val: 1 633 scomplex_val: 2 634 """, t) 635 a = tensor_util.MakeNdarray(t) 636 self.assertEqual(np.complex64, a.dtype) 637 self.assertAllEqual(np.array(1 + 2j), a) 638 639 def testComplex128(self): 640 t = tensor_util.make_tensor_proto((1 + 2j), dtype=dtypes.complex128) 641 self.assertProtoEquals(""" 642 dtype: DT_COMPLEX128 643 tensor_shape {} 644 dcomplex_val: 1 645 dcomplex_val: 2 646 """, t) 647 a = tensor_util.MakeNdarray(t) 648 self.assertEqual(np.complex128, a.dtype) 649 self.assertAllEqual(np.array(1 + 2j), a) 650 651 def testComplexWithImplicitRepeat(self): 652 for dtype, np_dtype in [(dtypes.complex64, np.complex64), 653 (dtypes.complex128, np.complex128)]: 654 t = tensor_util.make_tensor_proto((1 + 1j), shape=[3, 4], dtype=dtype) 655 a = tensor_util.MakeNdarray(t) 656 self.assertAllClose( 657 np.array( 658 [[(1 + 1j), (1 + 1j), (1 + 1j), (1 + 1j)], 659 [(1 + 1j), (1 + 1j), (1 + 1j), (1 + 1j)], 660 [(1 + 1j), (1 + 1j), (1 + 1j), (1 + 1j)]], 661 dtype=np_dtype), 662 a) 663 664 def testComplex64N(self): 665 t = tensor_util.make_tensor_proto( 666 [(1 + 2j), (3 + 4j), (5 + 6j)], shape=[1, 3], dtype=dtypes.complex64) 667 self.assertProtoEquals(""" 668 dtype: DT_COMPLEX64 669 tensor_shape { dim { size: 1 } dim { size: 3 } } 670 scomplex_val: 1 671 scomplex_val: 2 672 scomplex_val: 3 673 scomplex_val: 4 674 scomplex_val: 5 675 scomplex_val: 6 676 """, t) 677 a = tensor_util.MakeNdarray(t) 678 self.assertEqual(np.complex64, a.dtype) 679 self.assertAllEqual(np.array([[(1 + 2j), (3 + 4j), (5 + 6j)]]), a) 680 681 def testComplex128N(self): 682 t = tensor_util.make_tensor_proto( 683 [(1 + 2j), (3 + 4j), (5 + 6j)], shape=[1, 3], dtype=dtypes.complex128) 684 self.assertProtoEquals(""" 685 dtype: DT_COMPLEX128 686 tensor_shape { dim { size: 1 } dim { size: 3 } } 687 dcomplex_val: 1 688 dcomplex_val: 2 689 dcomplex_val: 3 690 dcomplex_val: 4 691 dcomplex_val: 5 692 dcomplex_val: 6 693 """, t) 694 a = tensor_util.MakeNdarray(t) 695 self.assertEqual(np.complex128, a.dtype) 696 self.assertAllEqual(np.array([[(1 + 2j), (3 + 4j), (5 + 6j)]]), a) 697 698 def testComplex64NpArray(self): 699 t = tensor_util.make_tensor_proto( 700 np.array([[(1 + 2j), (3 + 4j)], [(5 + 6j), (7 + 8j)]]), 701 dtype=dtypes.complex64) 702 # scomplex_val are real_0, imag_0, real_1, imag_1, ... 703 self.assertProtoEquals(""" 704 dtype: DT_COMPLEX64 705 tensor_shape { dim { size: 2 } dim { size: 2 } } 706 scomplex_val: 1 707 scomplex_val: 2 708 scomplex_val: 3 709 scomplex_val: 4 710 scomplex_val: 5 711 scomplex_val: 6 712 scomplex_val: 7 713 scomplex_val: 8 714 """, t) 715 a = tensor_util.MakeNdarray(t) 716 self.assertEqual(np.complex64, a.dtype) 717 self.assertAllEqual( 718 np.array([[(1 + 2j), (3 + 4j)], [(5 + 6j), (7 + 8j)]]), a) 719 720 def testComplex128NpArray(self): 721 t = tensor_util.make_tensor_proto( 722 np.array([[(1 + 2j), (3 + 4j)], [(5 + 6j), (7 + 8j)]]), 723 dtype=dtypes.complex128) 724 # scomplex_val are real_0, imag_0, real_1, imag_1, ... 725 self.assertProtoEquals(""" 726 dtype: DT_COMPLEX128 727 tensor_shape { dim { size: 2 } dim { size: 2 } } 728 dcomplex_val: 1 729 dcomplex_val: 2 730 dcomplex_val: 3 731 dcomplex_val: 4 732 dcomplex_val: 5 733 dcomplex_val: 6 734 dcomplex_val: 7 735 dcomplex_val: 8 736 """, t) 737 a = tensor_util.MakeNdarray(t) 738 self.assertEqual(np.complex128, a.dtype) 739 self.assertAllEqual( 740 np.array([[(1 + 2j), (3 + 4j)], [(5 + 6j), (7 + 8j)]]), a) 741 742 def testNestedNumpyArrayWithoutDType(self): 743 t = tensor_util.make_tensor_proto([10.0, 20.0, np.array(30.0)]) 744 a = tensor_util.MakeNdarray(t) 745 self.assertEqual(np.float32, a.dtype) 746 self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a) 747 748 def testNestedNumpyArrayWithDType(self): 749 t = tensor_util.make_tensor_proto([10.0, 20.0, np.array(30.0)], 750 dtype=dtypes.float32) 751 a = tensor_util.MakeNdarray(t) 752 self.assertEqual(np.float32, a.dtype) 753 self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a) 754 755 def testUnsupportedDTypes(self): 756 with self.assertRaises(TypeError): 757 tensor_util.make_tensor_proto(np.array([1]), 0) 758 with self.assertRaises(TypeError): 759 tensor_util.make_tensor_proto(3, dtype=dtypes.qint8) 760 with self.assertRaises(TypeError): 761 tensor_util.make_tensor_proto([3], dtype=dtypes.qint8) 762 763 # Validate the helpful error message when trying to convert an 764 # unconvertible list as strings. 765 with self.assertRaisesRegex(TypeError, "Failed to convert object"): 766 tensor_util.make_tensor_proto([tensor_shape.Dimension(1)]) 767 768 def testTensorShapeVerification(self): 769 array = np.array([[1], [2]]) 770 correct_shape = (2, 1) 771 incorrect_shape = (1, 2) 772 tensor_util.make_tensor_proto(array, shape=correct_shape, verify_shape=True) 773 with self.assertRaises(TypeError): 774 tensor_util.make_tensor_proto( 775 array, shape=incorrect_shape, verify_shape=True) 776 777 def testShapeTooLarge(self): 778 with self.assertRaises(ValueError): 779 tensor_util.make_tensor_proto(np.array([1, 2]), shape=[1]) 780 781 def testLowRankSupported(self): 782 t = tensor_util.make_tensor_proto(np.array(7)) 783 self.assertProtoEquals(""" 784 dtype: DT_INT64 785 tensor_shape {} 786 int64_val: 7 787 """, t) 788 789 def testShapeEquals(self): 790 t = tensor_util.make_tensor_proto([10, 20, 30, 40], shape=[2, 2]) 791 self.assertTrue(tensor_util.ShapeEquals(t, [2, 2])) 792 self.assertTrue(tensor_util.ShapeEquals(t, (2, 2))) 793 self.assertTrue( 794 tensor_util.ShapeEquals(t, tensor_shape.as_shape([2, 2]).as_proto())) 795 self.assertFalse(tensor_util.ShapeEquals(t, [5, 3])) 796 self.assertFalse(tensor_util.ShapeEquals(t, [1, 4])) 797 self.assertFalse(tensor_util.ShapeEquals(t, [4])) 798 799 800@test_util.run_all_in_graph_and_eager_modes 801class IsTensorTest(test.TestCase): 802 803 def testConstantTensor(self): 804 np_val = np.random.rand(3).astype(np.int32) 805 tf_val = constant_op.constant(np_val) 806 self.assertFalse(tensor_util.is_tf_type(np_val)) 807 self.assertTrue(tensor_util.is_tf_type(tf_val)) 808 809 def testRaggedTensor(self): 810 rt = ragged_factory_ops.constant([[1, 2], [3]]) 811 rt_value = self.evaluate(rt) 812 self.assertTrue(tensor_util.is_tf_type(rt)) 813 self.assertFalse(tensor_util.is_tf_type(rt_value)) 814 815 def testSparseTensor(self): 816 st = sparse_tensor.SparseTensor([[1, 2]], [3], [10, 10]) 817 st_value = self.evaluate(st) 818 self.assertTrue(tensor_util.is_tf_type(st)) 819 self.assertFalse(tensor_util.is_tf_type(st_value)) 820 821 def testIndexedSlices(self): 822 x = indexed_slices.IndexedSlices( 823 constant_op.constant([1, 2, 3]), constant_op.constant([10, 20, 30])) 824 x_value = indexed_slices.IndexedSlicesValue( 825 np.array([1, 2, 3]), np.array([10, 20, 30]), np.array([100])) 826 self.assertTrue(tensor_util.is_tf_type(x)) 827 self.assertFalse(tensor_util.is_tf_type(x_value)) 828 829 def testVariable(self): 830 v = variables.Variable([1, 2, 3]) 831 self.assertTrue(tensor_util.is_tf_type(v)) 832 833 834class ConstantValueTest(test.TestCase): 835 836 def testConstant(self): 837 np_val = np.random.rand(3, 4, 7).astype(np.float32) 838 tf_val = constant_op.constant(np_val) 839 self.assertAllClose(np_val, tensor_util.constant_value(tf_val)) 840 841 np_val = np.random.rand(3, 0, 7).astype(np.float32) 842 tf_val = constant_op.constant(np_val) 843 self.assertAllClose(np_val, tensor_util.constant_value(tf_val)) 844 845 def testUnknown(self): 846 with ops.Graph().as_default(): 847 tf_val = gen_state_ops.variable( 848 shape=[3, 4, 7], 849 dtype=dtypes.float32, 850 name="tf_val", 851 container="", 852 shared_name="") 853 self.assertIs(None, tensor_util.constant_value(tf_val)) 854 855 def testShape(self): 856 np_val = np.array([1, 2, 3], dtype=np.int32) 857 tf_val = array_ops.shape(constant_op.constant(0.0, shape=[1, 2, 3])) 858 c_val = tensor_util.constant_value(tf_val) 859 self.assertAllEqual(np_val, c_val) 860 self.assertEqual(np.int32, c_val.dtype) 861 862 def testFill(self): 863 np_val = np.array([-1, -1, -1], dtype=np.float32) 864 tf_val = array_ops.fill([3], constant_op.constant(-1.0)) 865 c_val = tensor_util.constant_value(tf_val) 866 self.assertAllEqual(np_val, c_val) 867 self.assertEqual(np.float32, c_val.dtype) 868 869 def testSize(self): 870 tf_val = array_ops.size(constant_op.constant(0.0, shape=[1, 2, 3])) 871 c_val = tensor_util.constant_value(tf_val) 872 self.assertEqual(6, c_val) 873 874 def testSizeOfScalar(self): 875 tf_val = array_ops.size(constant_op.constant(0.0)) 876 c_val = tensor_util.constant_value(tf_val) 877 self.assertEqual(1, c_val) 878 self.assertIn(type(c_val), [np.ndarray, np.int32]) 879 880 def testRank(self): 881 tf_val = array_ops.rank(constant_op.constant(0.0, shape=[1, 2, 3])) 882 c_val = tensor_util.constant_value(tf_val) 883 884 self.assertIn(type(c_val), [np.ndarray, np.int32]) 885 self.assertEqual((), c_val.shape) 886 self.assertEqual(3, c_val) 887 888 # Repeat test using array_ops.rank_internal to avoid the optimization that 889 # happens in the rank function. 890 tf_val = array_ops.rank_internal( 891 constant_op.constant( 892 0.0, shape=[1, 2, 3]), optimize=False) 893 c_val = tensor_util.constant_value(tf_val) 894 895 self.assertIn(type(c_val), [np.ndarray, np.int32]) 896 self.assertEqual((), c_val.shape) 897 self.assertEqual(3, c_val) 898 self.assertEqual([3], c_val) 899 900 def testCast(self): 901 np_val = np.random.rand(3, 4, 7).astype(np.float32) 902 tf_val = math_ops.cast(constant_op.constant(np_val), dtypes.float64) 903 c_val = tensor_util.constant_value(tf_val) 904 self.assertAllClose(np_val.astype(np.float64), c_val) 905 906 np_val = np.random.rand(3, 0, 7).astype(np.float32) 907 tf_val = math_ops.cast(constant_op.constant(np_val), dtypes.float64) 908 c_val = tensor_util.constant_value(tf_val) 909 self.assertAllClose(np_val.astype(np.float64), c_val) 910 911 def testConcat(self): 912 np_val = np.random.rand(3, 4, 7).astype(np.float32) 913 tf_val = array_ops.concat( 914 [np_val[0:1, :, :], np_val[1:2, :, :], np_val[2:3, :, :]], 0) 915 c_val = tensor_util.constant_value(tf_val) 916 self.assertAllClose(np_val, c_val) 917 918 # This test needs a placeholder which means we need to construct a graph. 919 with ops.Graph().as_default(): 920 tf_val = array_ops.concat( 921 [np_val[0, :, :], np_val[1, :, :], np_val[2, :, :]], 922 array_ops.placeholder(dtypes.int32)) 923 c_val = tensor_util.constant_value(tf_val) 924 self.assertIs(None, c_val) 925 926 tf_val = array_ops.concat([ 927 np_val[0, :, :], 928 array_ops.placeholder(dtypes.float32), np_val[2, :, :] 929 ], 1) 930 c_val = tensor_util.constant_value(tf_val) 931 self.assertIs(None, c_val) 932 933 def testPack_Axis0(self): 934 inputs = [np.random.rand(4, 7) for _ in range(3)] 935 np_val = np.array(inputs) 936 tf_val = array_ops.stack(inputs) 937 c_val = tensor_util.constant_value(tf_val) 938 self.assertAllClose(np_val, c_val) 939 940 # This test needs a placeholder which means we need to construct a graph. 941 with ops.Graph().as_default(): 942 tf_val = array_ops.stack( 943 [inputs[0], 944 array_ops.placeholder(dtypes.float32), inputs[2]]) 945 c_val = tensor_util.constant_value(tf_val) 946 self.assertIs(None, c_val) 947 948 def testPack_Axis1(self): 949 # This test needs a placeholder which means we need to construct a graph. 950 with ops.Graph().as_default(): 951 inputs = [np.random.rand(4, 7) for _ in range(3)] 952 tf_val = array_ops.stack(inputs, axis=1) 953 c_val = tensor_util.constant_value(tf_val) 954 self.assertIsNone(c_val) 955 956 tf_val = array_ops.stack( 957 [inputs[0], 958 array_ops.placeholder(dtypes.float32), inputs[2]], axis=1) 959 c_val = tensor_util.constant_value(tf_val) 960 self.assertIs(None, c_val) 961 962 def testPack_Partial_Axis0(self): 963 input_ = np.random.rand(4, 7) 964 # This test needs a placeholder which means we need to construct a graph. 965 with ops.Graph().as_default(): 966 tf_val = array_ops.stack([input_, array_ops.placeholder(dtypes.float32)]) 967 c_val = tensor_util.constant_value(tf_val, partial=True) 968 self.assertAllClose(input_, c_val[0]) 969 self.assertIsNone(c_val[1]) 970 971 def testPack_Partial_Axis1(self): 972 input_ = np.random.rand(4, 7) 973 # This test needs a placeholder which means we need to construct a graph. 974 with ops.Graph().as_default(): 975 tf_val = array_ops.stack( 976 [input_, array_ops.placeholder(dtypes.float32)], axis=1) 977 c_val = tensor_util.constant_value(tf_val, partial=True) 978 self.assertIsNone(c_val) 979 980 def testUnpack_Axis0(self): 981 inputs = np.random.rand(3, 4, 7) 982 tf_vals = array_ops.unstack(inputs) 983 c_vals = [tensor_util.constant_value(x) for x in tf_vals] 984 self.assertAllClose(inputs, c_vals) 985 986 def testUnpack_Partial_Axis0(self): 987 input_ = np.random.rand(4, 7) 988 # This test needs a placeholder which means we need to construct a graph. 989 with ops.Graph().as_default(): 990 packed = array_ops.stack([input_, array_ops.placeholder(dtypes.float32)]) 991 tf_vals = array_ops.unstack(packed) 992 c_vals = [tensor_util.constant_value(x, partial=True) for x in tf_vals] 993 self.assertAllClose(input_, c_vals[0]) 994 self.assertIsNone(c_vals[1]) 995 996 def testSplit_Axis0(self): 997 inputs = np.random.rand(6, 5, 7) 998 tf_vals = array_ops.split(inputs, 3) 999 c_vals = [tensor_util.constant_value(x) for x in tf_vals] 1000 self.assertAllClose(np.split(inputs, 3), c_vals) 1001 1002 def testSplit_Partial_Axis0(self): 1003 input_ = np.random.rand(4, 7) 1004 # This test needs a placeholder which means we need to construct a graph. 1005 with ops.Graph().as_default(): 1006 placeholder = array_ops.placeholder(dtypes.float32, shape=(4, 7)) 1007 # it'd be better to use concat here, but concat doesn't support partial 1008 packed = array_ops.stack([input_, placeholder]) 1009 tf_vals = array_ops.split(packed, 2) 1010 c_vals = [tensor_util.constant_value(x, partial=True) for x in tf_vals] 1011 self.assertAllClose(input_, c_vals[0][0]) 1012 self.assertIsNone(c_vals[1][0]) 1013 1014 def testEqual(self): 1015 # Scalar inputs. 1016 tf_val = math_ops.equal(constant_op.constant(1), constant_op.constant(1)) 1017 self.assertEqual(tensor_util.constant_value(tf_val), True) 1018 1019 tf_val = math_ops.equal(constant_op.constant(1), constant_op.constant(0)) 1020 self.assertEqual(tensor_util.constant_value(tf_val), False) 1021 1022 # Shaped inputs with broadcast semantics. 1023 tf_val = math_ops.equal(constant_op.constant([[0, 1]]), 1024 constant_op.constant([[0], [1]])) 1025 c_val = tensor_util.constant_value(tf_val) 1026 self.assertAllEqual(c_val, [[True, False], [False, True]]) 1027 1028 def testNotEqual(self): 1029 # Scalar inputs. 1030 tf_val = math_ops.not_equal(constant_op.constant(1), 1031 constant_op.constant(1)) 1032 self.assertEqual(tensor_util.constant_value(tf_val), False) 1033 1034 tf_val = math_ops.not_equal(constant_op.constant(1), 1035 constant_op.constant(0)) 1036 self.assertEqual(tensor_util.constant_value(tf_val), True) 1037 1038 # Shaped inputs with broadcast semantics. 1039 tf_val = math_ops.not_equal(constant_op.constant([[0, 1]]), 1040 constant_op.constant([[0], [1]])) 1041 c_val = tensor_util.constant_value(tf_val) 1042 self.assertAllEqual(c_val, [[False, True], [True, False]]) 1043 1044 def testStopGradient(self): 1045 input_ = np.random.rand(4, 7) 1046 tf_val = array_ops.stop_gradient(input_) 1047 c_val = tensor_util.constant_value(tf_val) 1048 self.assertAllEqual(input_, c_val) 1049 1050 def testIdentity(self): 1051 input_ = np.random.rand(4, 7) 1052 tf_val = array_ops.identity(input_) 1053 c_val = tensor_util.constant_value(tf_val) 1054 self.assertAllEqual(input_, c_val) 1055 1056 def testLiteral(self): 1057 x = "hi" 1058 self.assertIs(x, tensor_util.constant_value(x)) 1059 1060 def testNumpyNdarray(self): 1061 np_val = np.random.rand(3, 4, 7).astype(np.float32) 1062 self.assertIs(np_val, tensor_util.constant_value(np_val)) 1063 1064 def testVariable(self): 1065 var = variables.Variable(1.0, name="variable_node") 1066 self.assertIsNone(tensor_util.constant_value(var)) 1067 1068 def testVariableV1(self): 1069 var = variables.VariableV1(1.0, name="variable_node") 1070 self.assertIsNone(tensor_util.constant_value(var)) 1071 1072 1073class ConstantValueAsShapeTest(test.TestCase): 1074 1075 @test_util.run_in_graph_and_eager_modes 1076 def testConstant(self): 1077 np_val = np.random.rand(3).astype(np.int32) 1078 tf_val = constant_op.constant(np_val) 1079 self.assertEqual( 1080 tensor_shape.TensorShape(np_val), 1081 tensor_util.constant_value_as_shape(tf_val)) 1082 1083 tf_val = constant_op.constant([], dtype=dtypes.int32) 1084 self.assertEqual( 1085 tensor_shape.TensorShape([]), 1086 tensor_util.constant_value_as_shape(tf_val)) 1087 1088 @test_util.run_in_graph_and_eager_modes 1089 def testCast(self): 1090 tf_val = math_ops.cast( 1091 array_ops.shape(constant_op.constant(0.0, shape=[1, 2, 3])), 1092 dtypes.int64) 1093 c_val = tensor_util.constant_value_as_shape(tf_val) 1094 self.assertEqual(tensor_shape.TensorShape([1, 2, 3]), c_val) 1095 1096 @test_util.run_in_graph_and_eager_modes 1097 def testCastWithUnknown(self): 1098 tf_val = math_ops.cast(constant_op.constant([-1, 1, -1]), dtypes.int64) 1099 c_val = tensor_util.constant_value_as_shape(tf_val) 1100 self.assertEqual([None, 1, None], c_val.as_list()) 1101 1102 @test_util.run_in_graph_and_eager_modes 1103 def testShape(self): 1104 tf_val = array_ops.shape(constant_op.constant(0.0, shape=[1, 2, 3])) 1105 c_val = tensor_util.constant_value_as_shape(tf_val) 1106 self.assertEqual(tensor_shape.TensorShape([1, 2, 3]), c_val) 1107 1108 @test_util.run_in_graph_and_eager_modes 1109 def testMinusOneBecomesNone(self): 1110 tf_val = constant_op.constant([-1, 1, -1], shape=[3]) 1111 c_val = tensor_util.constant_value_as_shape(tf_val) 1112 self.assertEqual([None, 1, None], c_val.as_list()) 1113 1114 def testPack(self): 1115 # This test needs a placeholder which means we need to construct a graph. 1116 with ops.Graph().as_default(): 1117 tf_val = array_ops.stack( 1118 [constant_op.constant(16), 37, 1119 array_ops.placeholder(dtypes.int32)]) 1120 c_val = tensor_util.constant_value_as_shape(tf_val) 1121 self.assertEqual([16, 37, None], c_val.as_list()) 1122 1123 def testConcat(self): 1124 # This test needs a placeholder which means we need to construct a graph. 1125 with ops.Graph().as_default(): 1126 tf_val = array_ops.concat( 1127 [[16, 37], array_ops.placeholder(dtypes.int32, shape=(2,))], 0) 1128 c_val = tensor_util.constant_value_as_shape(tf_val) 1129 self.assertEqual([16, 37, None, None], c_val.as_list()) 1130 1131 tf_val = array_ops.concat( 1132 [[16, 37], 1133 array_ops.placeholder(dtypes.int32, shape=(1,)), [48]], 0) 1134 c_val = tensor_util.constant_value_as_shape(tf_val) 1135 self.assertEqual([16, 37, None, 48], c_val.as_list()) 1136 1137 def testSlice(self): 1138 # This test needs a placeholder which means we need to construct a graph. 1139 with ops.Graph().as_default(): 1140 tf_val = array_ops.placeholder(dtypes.int32, shape=(4,))[0:2] 1141 c_val = tensor_util.constant_value_as_shape(tf_val) 1142 self.assertEqual([None, None], c_val.as_list()) 1143 1144 # begin:end 1145 tf_val = constant_op.constant([10, 20, 30])[1:3] 1146 c_val = tensor_util.constant_value_as_shape(tf_val) 1147 self.assertEqual([20, 30], c_val.as_list()) 1148 1149 # begin:end:stride 1150 tf_val = array_ops.strided_slice( 1151 constant_op.constant([10, 20, 30]), [1], [3], strides=[2]) 1152 c_val = tensor_util.constant_value_as_shape(tf_val) 1153 self.assertEqual([20], c_val.as_list()) 1154 1155 # [1, 2, 16, 37, None, 48] 1156 # This test needs a placeholder which means we need to construct a graph. 1157 with ops.Graph().as_default(): 1158 tf_val_orig = array_ops.concat( 1159 [[1, 2, 16, 37], 1160 array_ops.placeholder(dtypes.int32, shape=(1,)), [48]], 0) 1161 1162 # begin: no end 1163 tf_val = tf_val_orig[2:] 1164 c_val = tensor_util.constant_value_as_shape(tf_val) 1165 self.assertEqual([16, 37, None, 48], c_val.as_list()) 1166 1167 # begin::negative slice 1168 tf_val = tf_val_orig[2::-1] 1169 c_val = tensor_util.constant_value_as_shape(tf_val) 1170 self.assertEqual([16, 2, 1], c_val.as_list()) 1171 1172 # :end:negative slice 1173 tf_val = tf_val_orig[:1:-2] 1174 c_val = tensor_util.constant_value_as_shape(tf_val) 1175 self.assertEqual([48, 37], c_val.as_list()) 1176 1177 # begin:end:negative slice 1178 tf_val = tf_val_orig[3:1:-1] 1179 c_val = tensor_util.constant_value_as_shape(tf_val) 1180 self.assertEqual([37, 16], c_val.as_list()) 1181 1182 # begin:negative end:slice 1183 tf_val = tf_val_orig[1:-3:1] 1184 c_val = tensor_util.constant_value_as_shape(tf_val) 1185 self.assertEqual([2, 16], c_val.as_list()) 1186 1187 # negative begin::slice 1188 tf_val = tf_val_orig[-3::1] 1189 c_val = tensor_util.constant_value_as_shape(tf_val) 1190 self.assertEqual([37, None, 48], c_val.as_list()) 1191 1192 # negative begin::negative slice 1193 tf_val = tf_val_orig[-3::-1] 1194 c_val = tensor_util.constant_value_as_shape(tf_val) 1195 self.assertEqual([37, 16, 2, 1], c_val.as_list()) 1196 1197 # negative begin:negative end:negative slice 1198 tf_val = tf_val_orig[-3:-5:-1] 1199 c_val = tensor_util.constant_value_as_shape(tf_val) 1200 self.assertEqual([37, 16], c_val.as_list()) 1201 1202 # Do not support shape inference for additional arguments 1203 tf_val = constant_op.constant([10, 20, 30])[...] 1204 c_val = tensor_util.constant_value_as_shape(tf_val) 1205 self.assertEqual([None, None, None], c_val.as_list()) 1206 1207 # Do not support shape inference for tensor slices. 1208 tf_val = constant_op.constant( 1209 [10, 20, 30])[array_ops.placeholder(dtypes.int32, shape=()):] 1210 c_val = tensor_util.constant_value_as_shape(tf_val) 1211 self.assertEqual(tensor_shape.unknown_shape(), c_val) 1212 1213 # Do not support shape inference for higher rank 1214 with self.assertRaises(ValueError): 1215 tf_val = constant_op.constant([[10], [20], [30]])[:, 0:] 1216 c_val = tensor_util.constant_value_as_shape(tf_val) 1217 1218 1219class MaybeSetStaticShapeTest(test.TestCase): 1220 1221 @contextlib.contextmanager 1222 def disableSetStaticShape(self): 1223 flag_old = tensor_util._ENABLE_MAYBE_SET_STATIC_SHAPE 1224 tensor_util._ENABLE_MAYBE_SET_STATIC_SHAPE = False 1225 try: 1226 yield 1227 finally: 1228 tensor_util._ENABLE_MAYBE_SET_STATIC_SHAPE = flag_old 1229 1230 def testMaybeSetStaticShape(self): 1231 shape = constant_op.constant([2, 5], dtype=dtypes.int32) 1232 1233 def reshape(): 1234 v = array_ops.zeros([10]) 1235 return array_ops.reshape(v, shape) 1236 # This test needs a placeholder which means we need to construct a graph. 1237 with ops.Graph().as_default(): 1238 with self.disableSetStaticShape(): 1239 graph_without_shape_propagation = func_graph.func_graph_from_py_func( 1240 "without_shape_propagation", reshape, [], {}) 1241 graph_with_shape_propagation = func_graph.func_graph_from_py_func( 1242 "with_shape_propagation", reshape, [], {}) 1243 self.assertCountEqual( 1244 [op.type for op in graph_without_shape_propagation.get_operations()], 1245 [op.type for op in graph_with_shape_propagation.get_operations()]) 1246 1247 def testMaybeSetStaticShapeScalarShape(self): 1248 1249 def reshape(): 1250 v = array_ops.placeholder(dtypes.float32) 1251 t = array_ops.reshape(v, [-1]) 1252 return t 1253 1254 with self.disableSetStaticShape(): 1255 graph_without_shape_propagation = func_graph.func_graph_from_py_func( 1256 "without_shape_propagation", reshape, [], {}) 1257 graph_with_shape_propagation = func_graph.func_graph_from_py_func( 1258 "with_shape_propagation", reshape, [], {}) 1259 self.assertCountEqual( 1260 [op.type for op in graph_without_shape_propagation.get_operations()], 1261 [op.type for op in graph_with_shape_propagation.get_operations()]) 1262 1263 1264class ShapeTensorTest(test_util.TensorFlowTestCase): 1265 1266 @test_util.run_in_graph_and_eager_modes 1267 def testConversion(self): 1268 """Make sure fully known TensorShape objects convert to Tensors.""" 1269 shape = tensor_shape.TensorShape([1, tensor_shape.Dimension(2)]) 1270 shape_tensor = tensor_util.shape_tensor(shape) 1271 self.assertAllEqual((1, 2), shape_tensor) 1272 1273 1274if __name__ == "__main__": 1275 test.main() 1276