# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for tensor_util."""

import contextlib
import sys

from absl.testing import parameterized
import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class TensorUtilTest(test.TestCase, parameterized.TestCase):
  """Tests for make_tensor_proto, MakeNdarray and ShapeEquals.

  Most tests build a TensorProto with `tensor_util.make_tensor_proto`, compare
  it against an expected text-format proto, then convert it back with
  `tensor_util.MakeNdarray` and check dtype and values.

  Tests that pin multi-byte `tensor_content` branch on `sys.byteorder`, because
  the expected byte strings differ between big- and little-endian hosts.
  """

  def testFloat(self):
    value = 10.0
    t = tensor_util.make_tensor_proto(value)
    self.assertProtoEquals("""
      dtype: DT_FLOAT
      tensor_shape {}
      float_val: %.1f
      """ % value, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array(value, dtype=np.float32), a)

  def testFloatN(self):
    t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0])
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "A \000\000A\240\000\000A\360\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\000 A\000\000\240A\000\000\360A"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testFloatTyped(self):
    t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], dtype=dtypes.float32)
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "A \000\000A\240\000\000A\360\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\000 A\000\000\240A\000\000\360A"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testFloatTypeCoerce(self):
    # Integer inputs with an explicit float32 dtype.
    t = tensor_util.make_tensor_proto([10, 20, 30], dtype=dtypes.float32)
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "A \000\000A\240\000\000A\360\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\000 A\000\000\240A\000\000\360A"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testFloatTypeCoerceNdarray(self):
    arr = np.asarray([10, 20, 30], dtype="int")
    t = tensor_util.make_tensor_proto(arr, dtype=dtypes.float32)
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "A \000\000A\240\000\000A\360\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\000 A\000\000\240A\000\000\360A"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testFloatSizes(self):
    t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], shape=[1, 3])
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 1 } dim { size: 3 } }
        tensor_content: "A \000\000A\240\000\000A\360\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 1 } dim { size: 3 } }
        tensor_content: "\000\000 A\000\000\240A\000\000\360A"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([[10.0, 20.0, 30.0]], dtype=np.float32), a)

  def testFloatSizes2(self):
    t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], shape=[3, 1])
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } dim { size: 1 } }
        tensor_content: "A \000\000A\240\000\000A\360\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } dim { size: 1 } }
        tensor_content: "\000\000 A\000\000\240A\000\000\360A"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([[10.0], [20.0], [30.0]], dtype=np.float32), a)

  def testFloatSizesLessValues(self):
    t = tensor_util.make_tensor_proto(10.0, shape=[1, 3])
    self.assertProtoEquals("""
      dtype: DT_FLOAT
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      float_val: 10.0
      """, t)
    # No conversion to Ndarray for this one: not enough values.

  def testFloatNpArrayFloat64(self):
    t = tensor_util.make_tensor_proto(
        np.array([[10.0, 20.0, 30.0]], dtype=np.float64))
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_DOUBLE
        tensor_shape { dim { size: 1 } dim { size: 3 } }
        tensor_content: "@$\000\000\000\000\000\000@4\000\000\000\000\000\000@>\000\000\000\000\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_DOUBLE
        tensor_shape { dim { size: 1 } dim { size: 3 } }
        tensor_content: "\000\000\000\000\000\000$@\000\000\000\000\000\0004@\000\000\000\000\000\000>@"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float64, a.dtype)
    self.assertAllClose(
        np.array([[10.0, 20.0, 30.0]], dtype=np.float64),
        tensor_util.MakeNdarray(t))

  def testFloatTypesWithImplicitRepeat(self):
    # A single value with a larger `shape` is expected to fill the whole array
    # once converted back with MakeNdarray.
    for dtype, nptype in [(dtypes.float32, np.float32),
                          (dtypes.float64, np.float64)]:
      t = tensor_util.make_tensor_proto([10.0], shape=[3, 4], dtype=dtype)
      a = tensor_util.MakeNdarray(t)
      self.assertAllClose(
          np.array(
              [[10.0, 10.0, 10.0, 10.0],
               [10.0, 10.0, 10.0, 10.0],
               [10.0, 10.0, 10.0, 10.0]],
              dtype=nptype),
          a)

  def testFloatMutateArray(self):
    t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], dtype=dtypes.float32)
    a = tensor_util.MakeNdarray(t)
    a[0] = 5.0
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([5.0, 20.0, 30.0], dtype=np.float32), a)
    # Writing to the returned array must not change the proto it came from.
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "A \000\000A\240\000\000A\360\000\000"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_FLOAT
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\000 A\000\000\240A\000\000\360A"
        """, t)

  def testHalf(self):
    t = tensor_util.make_tensor_proto(np.array([10.0, 20.0], dtype=np.float16))
    # float16 10.0 is 0x4900 and 20.0 is 0x4D00 ('I' == 0x49, 'M' == 0x4D).
    # Like the other tensor_content tests in this class, the expected bytes
    # depend on the host byte order.
    if sys.byteorder == "big":
      self.assertProtoEquals(
          """
        dtype: DT_HALF
        tensor_shape { dim { size: 2 } }
        tensor_content: "I\000M\000"
        """, t)
    else:
      self.assertProtoEquals(
          """
        dtype: DT_HALF
        tensor_shape { dim { size: 2 } }
        tensor_content: "\000I\000M"
        """, t)

    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float16, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0], dtype=np.float16), a)

  def testBfloat16(self):
    test_type = dtypes.bfloat16.as_numpy_dtype
    t = tensor_util.make_tensor_proto(np.array([10.0, 20.0], dtype=test_type))
    # 10.0: 16672 = 010000010(130) 0100000: (1+0/2+1/4) * 2^(130-127)
    # 20.0: 16800 = 010000011(131) 0100000: (1+0/2+1/4) * 2^(131-127)
    self.assertProtoEquals("""
      dtype: DT_BFLOAT16
      tensor_shape {
        dim {
          size: 2
        }
      }
      half_val: 16672
      half_val: 16800
      """, t)

    a = tensor_util.MakeNdarray(t)
    self.assertEqual(test_type, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0], dtype=test_type), a)

  def testInt(self):
    t = tensor_util.make_tensor_proto(10)
    self.assertProtoEquals("""
      dtype: DT_INT32
      tensor_shape {}
      int_val: 10
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.int32, a.dtype)
    self.assertAllClose(np.array(10, dtype=np.int32), a)

  def testLargeInt(self):
    value = np.iinfo(np.int64).max
    t = tensor_util.make_tensor_proto(value)
    self.assertProtoEquals("""
      dtype: DT_INT64
      tensor_shape {}
      int64_val: %d
      """ % value, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.int64, a.dtype)
    self.assertAllClose(np.array(value, dtype=np.int64), a)

  def testLargeNegativeInt(self):
    # We don't use the min np.int64 value here
    # because it breaks np.abs().
    #
    # np.iinfo(np.int64).min = -9223372036854775808
    # np.iinfo(np.int64).max = 9223372036854775807
    # np.abs(-9223372036854775808) = -9223372036854775808
    value = np.iinfo(np.int64).min + 1
    t = tensor_util.make_tensor_proto(value)
    self.assertProtoEquals("""
      dtype: DT_INT64
      tensor_shape {}
      int64_val: %d
      """ % value, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.int64, a.dtype)
    self.assertAllClose(np.array(value, dtype=np.int64), a)

  def testIntNDefaultType(self):
    t = tensor_util.make_tensor_proto([10, 20, 30, 40], shape=[2, 2])
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_INT32
        tensor_shape { dim { size: 2 } dim { size: 2 } }
        tensor_content: "\000\000\000\n\000\000\000\024\000\000\000\036\000\000\000("
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_INT32
        tensor_shape { dim { size: 2 } dim { size: 2 } }
        tensor_content: "\n\000\000\000\024\000\000\000\036\000\000\000(\000\000\000"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.int32, a.dtype)
    self.assertAllClose(np.array([[10, 20], [30, 40]], dtype=np.int32), a)

  @parameterized.named_parameters(
      ("_int8", dtypes.int8, np.int8), ("_int16", dtypes.int16, np.int16),
      ("_int32", dtypes.int32, np.int32), ("_int64", dtypes.int64, np.int64),
      ("_uint8", dtypes.uint8, np.uint8), ("_uint16", dtypes.uint16, np.uint16),
      ("_uint32", dtypes.uint32, np.uint32),
      ("_uint64", dtypes.uint64, np.uint64))
  def testIntTypes(self, dtype, nptype):
    # Test with array.
    t = tensor_util.make_tensor_proto([10, 20, 30], dtype=dtype)
    self.assertEqual(dtype, t.dtype)
    self.assertProtoEquals("dim { size: 3 }", t.tensor_shape)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(nptype, a.dtype)
    self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)
    # Test with ndarray.
    t = tensor_util.make_tensor_proto(np.array([10, 20, 30], dtype=nptype))
    self.assertEqual(dtype, t.dtype)
    self.assertProtoEquals("dim { size: 3 }", t.tensor_shape)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(nptype, a.dtype)
    self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)

  @parameterized.named_parameters(
      ("_int8", dtypes.int8, np.int8), ("_int16", dtypes.int16, np.int16),
      ("_int32", dtypes.int32, np.int32), ("_int64", dtypes.int64, np.int64),
      ("_uint8", dtypes.uint8, np.uint8), ("_uint16", dtypes.uint16, np.uint16),
      ("_uint32", dtypes.uint32, np.uint32),
      ("_uint64", dtypes.uint64, np.uint64))
  def testIntTypesWithImplicitRepeat(self, dtype, nptype):
    # Fewer values than the shape holds: the expected array repeats the last
    # supplied value (12) in the remaining slots.
    self.assertAllEqual(
        np.array([[10, 11, 12, 12], [12, 12, 12, 12], [12, 12, 12, 12]],
                 dtype=nptype),
        tensor_util.MakeNdarray(
            tensor_util.make_tensor_proto([10, 11, 12],
                                          shape=[3, 4],
                                          dtype=dtype)))

  def testIntMixedWithDimension(self):
    # Github issue: 11974
    dtype = dtypes.int32
    nptype = np.int32
    t = tensor_util.make_tensor_proto(
        [10, tensor_shape.Dimension(20), 30], dtype=dtype)
    self.assertEqual(dtype, t.dtype)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(nptype, a.dtype)
    self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)

  @parameterized.named_parameters(
      ("_int64", dtypes.int64, np.int64, "DT_INT64", "int64_val"),
      ("_uint64", dtypes.uint64, np.uint64, "DT_UINT64", "uint64_val"))
  def testLong(self, dtype, nptype, proto_dtype, proto_value_name):
    t = tensor_util.make_tensor_proto(10, dtype=dtype)
    self.assertProtoEquals(
        """
      dtype: %s
      tensor_shape {}
      %s: 10
      """ % (proto_dtype, proto_value_name), t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(nptype, a.dtype)
    self.assertAllClose(np.array(10, dtype=nptype), a)

  @parameterized.named_parameters(
      ("_int64", dtypes.int64, np.int64, "DT_INT64"),
      ("_uint64", dtypes.uint64, np.uint64, "DT_UINT64"))
  def testLongN(self, dtype, nptype, proto_dtype):
    t = tensor_util.make_tensor_proto([10, 20, 30], shape=[1, 3], dtype=dtype)
    if sys.byteorder == "big":
      # pylint: disable=line-too-long
      self.assertProtoEquals(
          r"""
        dtype: %s
        tensor_shape { dim { size: 1 } dim { size: 3 } }
        tensor_content: "\000\000\000\000\000\000\000\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036"
        """ % proto_dtype, t)
      # pylint: enable=line-too-long
    else:
      # pylint: disable=line-too-long
      self.assertProtoEquals(
          r"""
        dtype: %s
        tensor_shape { dim { size: 1 } dim { size: 3 } }
        tensor_content: "\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
        """ % proto_dtype, t)
      # pylint: enable=line-too-long
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(nptype, a.dtype)
    self.assertAllClose(np.array([[10, 20, 30]], dtype=nptype), a)

  @parameterized.named_parameters(("_int64", np.int64, "DT_INT64"),
                                  ("_uint64", np.uint64, "DT_UINT64"))
  def testLongNpArray(self, nptype, proto_dtype):
    t = tensor_util.make_tensor_proto(np.array([10, 20, 30], dtype=nptype))
    if sys.byteorder == "big":
      # pylint: disable=line-too-long
      self.assertProtoEquals(
          r"""
        dtype: %s
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\000\000\000\000\000\000\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036"
        """ % proto_dtype, t)
      # pylint: enable=line-too-long
    else:
      # pylint: disable=line-too-long
      self.assertProtoEquals(
          r"""
        dtype: %s
        tensor_shape { dim { size: 3 } }
        tensor_content: "\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
        """ % proto_dtype, t)
      # pylint: enable=line-too-long
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(nptype, a.dtype)
    self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)

  def testQuantizedTypes(self):
    # Test with array.
    data = [(21,), (22,), (23,)]

    t = tensor_util.make_tensor_proto(data, dtype=dtypes.qint32)
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_QINT32
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\000\000\025\000\000\000\026\000\000\000\027"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_QINT32
        tensor_shape { dim { size: 3 } }
        tensor_content: "\025\000\000\000\026\000\000\000\027\000\000\000"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(dtypes.qint32.as_numpy_dtype, a.dtype)
    self.assertAllEqual(np.array(data, dtype=a.dtype), a)

    # Single-byte quantized types need no byte-order branch.
    t = tensor_util.make_tensor_proto(data, dtype=dtypes.quint8)
    self.assertProtoEquals(r"""
      dtype: DT_QUINT8
      tensor_shape { dim { size: 3 } }
      tensor_content: "\025\026\027"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(dtypes.quint8.as_numpy_dtype, a.dtype)
    self.assertAllEqual(np.array(data, dtype=a.dtype), a)

    t = tensor_util.make_tensor_proto(data, dtype=dtypes.qint8)
    self.assertProtoEquals(r"""
      dtype: DT_QINT8
      tensor_shape { dim { size: 3 } }
      tensor_content: "\025\026\027"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(dtypes.qint8.as_numpy_dtype, a.dtype)
    self.assertAllEqual(np.array(data, dtype=a.dtype), a)

    t = tensor_util.make_tensor_proto(data, dtype=dtypes.quint16)
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_QUINT16
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\025\000\026\000\027"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_QUINT16
        tensor_shape { dim { size: 3 } }
        tensor_content: "\025\000\026\000\027\000"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(dtypes.quint16.as_numpy_dtype, a.dtype)
    self.assertAllEqual(np.array(data, dtype=a.dtype), a)

    t = tensor_util.make_tensor_proto(data, dtype=dtypes.qint16)
    if sys.byteorder == "big":
      self.assertProtoEquals(r"""
        dtype: DT_QINT16
        tensor_shape { dim { size: 3 } }
        tensor_content: "\000\025\000\026\000\027"
        """, t)
    else:
      self.assertProtoEquals(r"""
        dtype: DT_QINT16
        tensor_shape { dim { size: 3 } }
        tensor_content: "\025\000\026\000\027\000"
        """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(dtypes.qint16.as_numpy_dtype, a.dtype)
    self.assertAllEqual(np.array(data, dtype=a.dtype), a)

  def testString(self):
    t = tensor_util.make_tensor_proto("foo")
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape {}
      string_val: "foo"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.object_, a.dtype)
    self.assertEqual([b"foo"], a)

  def testStringWithImplicitRepeat(self):
    t = tensor_util.make_tensor_proto(["f", "g"], shape=[3, 4])
    a = tensor_util.MakeNdarray(t)
    self.assertAllEqual(
        np.array([[b"f", b"g", b"g", b"g"], [b"g", b"g", b"g", b"g"],
                  [b"g", b"g", b"g", b"g"]],
                 dtype=np.object_), a)

  def testStringN(self):
    t = tensor_util.make_tensor_proto([b"foo", b"bar", b"baz"], shape=[1, 3])
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      string_val: "foo"
      string_val: "bar"
      string_val: "baz"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.object_, a.dtype)
    self.assertAllEqual(np.array([[b"foo", b"bar", b"baz"]]), a)

  def testStringNpArray(self):
    t = tensor_util.make_tensor_proto(
        np.array([[b"a", b"ab"], [b"abc", b"abcd"]]))
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape { dim { size: 2 } dim { size: 2 } }
      string_val: "a"
      string_val: "ab"
      string_val: "abc"
      string_val: "abcd"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.object_, a.dtype)
    self.assertAllEqual(np.array([[b"a", b"ab"], [b"abc", b"abcd"]]), a)

  def testArrayMethod(self):

    # Object that is only convertible through the `__array__` protocol.
    class Wrapper(object):

      def __array__(self, dtype=None):
        del dtype
        return np.array([b"foo", b"bar", b"baz"])

    t = tensor_util.make_tensor_proto(Wrapper(), shape=[1, 3])
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      string_val: "foo"
      string_val: "bar"
      string_val: "baz"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.object_, a.dtype)
    self.assertAllEqual(np.array([[b"foo", b"bar", b"baz"]]), a)

  def testArrayInterface(self):

    # Object that exposes its data only through `__array_interface__`.
    class Wrapper(object):

      def __init__(self):
        self.a = np.array([b"foo", b"bar", b"baz"])

      @property
      def __array_interface__(self):
        return self.a.__array_interface__

    t = tensor_util.make_tensor_proto(Wrapper(), shape=[1, 3])
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      string_val: "foo"
      string_val: "bar"
      string_val: "baz"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.object_, a.dtype)
    self.assertAllEqual(np.array([[b"foo", b"bar", b"baz"]]), a)

  def testStringTuple(self):
    t = tensor_util.make_tensor_proto((b"a", b"ab", b"abc", b"abcd"))
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape { dim { size: 4 } }
      string_val: "a"
      string_val: "ab"
      string_val: "abc"
      string_val: "abcd"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.object_, a.dtype)
    self.assertAllEqual(np.array((b"a", b"ab", b"abc", b"abcd")), a)

  def testStringNestedTuple(self):
    t = tensor_util.make_tensor_proto(((b"a", b"ab"), (b"abc", b"abcd")))
    self.assertProtoEquals("""
      dtype: DT_STRING
      tensor_shape { dim { size: 2 } dim { size: 2 } }
      string_val: "a"
      string_val: "ab"
      string_val: "abc"
      string_val: "abcd"
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.object_, a.dtype)
    self.assertAllEqual(np.array(((b"a", b"ab"), (b"abc", b"abcd"))), a)

  def testComplex64(self):
    t = tensor_util.make_tensor_proto((1 + 2j), dtype=dtypes.complex64)
    self.assertProtoEquals("""
      dtype: DT_COMPLEX64
      tensor_shape {}
      scomplex_val: 1
      scomplex_val: 2
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.complex64, a.dtype)
    self.assertAllEqual(np.array(1 + 2j), a)

  def testComplex128(self):
    t = tensor_util.make_tensor_proto((1 + 2j), dtype=dtypes.complex128)
    self.assertProtoEquals("""
      dtype: DT_COMPLEX128
      tensor_shape {}
      dcomplex_val: 1
      dcomplex_val: 2
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.complex128, a.dtype)
    self.assertAllEqual(np.array(1 + 2j), a)

  def testComplexWithImplicitRepeat(self):
    for dtype, np_dtype in [(dtypes.complex64, np.complex64),
                            (dtypes.complex128, np.complex128)]:
      t = tensor_util.make_tensor_proto((1 + 1j), shape=[3, 4], dtype=dtype)
      a = tensor_util.MakeNdarray(t)
      self.assertAllClose(
          np.array(
              [[(1 + 1j), (1 + 1j), (1 + 1j), (1 + 1j)],
               [(1 + 1j), (1 + 1j), (1 + 1j), (1 + 1j)],
               [(1 + 1j), (1 + 1j), (1 + 1j), (1 + 1j)]],
              dtype=np_dtype),
          a)

  def testComplex64N(self):
    t = tensor_util.make_tensor_proto(
        [(1 + 2j), (3 + 4j), (5 + 6j)], shape=[1, 3], dtype=dtypes.complex64)
    self.assertProtoEquals("""
      dtype: DT_COMPLEX64
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      scomplex_val: 1
      scomplex_val: 2
      scomplex_val: 3
      scomplex_val: 4
      scomplex_val: 5
      scomplex_val: 6
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.complex64, a.dtype)
    self.assertAllEqual(np.array([[(1 + 2j), (3 + 4j), (5 + 6j)]]), a)

  def testComplex128N(self):
    t = tensor_util.make_tensor_proto(
        [(1 + 2j), (3 + 4j), (5 + 6j)], shape=[1, 3], dtype=dtypes.complex128)
    self.assertProtoEquals("""
      dtype: DT_COMPLEX128
      tensor_shape { dim { size: 1 } dim { size: 3 } }
      dcomplex_val: 1
      dcomplex_val: 2
      dcomplex_val: 3
      dcomplex_val: 4
      dcomplex_val: 5
      dcomplex_val: 6
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.complex128, a.dtype)
    self.assertAllEqual(np.array([[(1 + 2j), (3 + 4j), (5 + 6j)]]), a)

  def testComplex64NpArray(self):
    t = tensor_util.make_tensor_proto(
        np.array([[(1 + 2j), (3 + 4j)], [(5 + 6j), (7 + 8j)]]),
        dtype=dtypes.complex64)
    # scomplex_val are real_0, imag_0, real_1, imag_1, ...
    self.assertProtoEquals("""
      dtype: DT_COMPLEX64
      tensor_shape { dim { size: 2 } dim { size: 2 } }
      scomplex_val: 1
      scomplex_val: 2
      scomplex_val: 3
      scomplex_val: 4
      scomplex_val: 5
      scomplex_val: 6
      scomplex_val: 7
      scomplex_val: 8
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.complex64, a.dtype)
    self.assertAllEqual(
        np.array([[(1 + 2j), (3 + 4j)], [(5 + 6j), (7 + 8j)]]), a)

  def testComplex128NpArray(self):
    t = tensor_util.make_tensor_proto(
        np.array([[(1 + 2j), (3 + 4j)], [(5 + 6j), (7 + 8j)]]),
        dtype=dtypes.complex128)
    # dcomplex_val are real_0, imag_0, real_1, imag_1, ...
    self.assertProtoEquals("""
      dtype: DT_COMPLEX128
      tensor_shape { dim { size: 2 } dim { size: 2 } }
      dcomplex_val: 1
      dcomplex_val: 2
      dcomplex_val: 3
      dcomplex_val: 4
      dcomplex_val: 5
      dcomplex_val: 6
      dcomplex_val: 7
      dcomplex_val: 8
      """, t)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.complex128, a.dtype)
    self.assertAllEqual(
        np.array([[(1 + 2j), (3 + 4j)], [(5 + 6j), (7 + 8j)]]), a)

  def testNestedNumpyArrayWithoutDType(self):
    t = tensor_util.make_tensor_proto([10.0, 20.0, np.array(30.0)])
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testNestedNumpyArrayWithDType(self):
    t = tensor_util.make_tensor_proto([10.0, 20.0, np.array(30.0)],
                                      dtype=dtypes.float32)
    a = tensor_util.MakeNdarray(t)
    self.assertEqual(np.float32, a.dtype)
    self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a)

  def testUnsupportedDTypes(self):
    with self.assertRaises(TypeError):
      tensor_util.make_tensor_proto(np.array([1]), 0)
    with self.assertRaises(TypeError):
      tensor_util.make_tensor_proto(3, dtype=dtypes.qint8)
    with self.assertRaises(TypeError):
      tensor_util.make_tensor_proto([3], dtype=dtypes.qint8)

    # Validate the helpful error message when trying to convert an
    # unconvertible list as strings.
    with self.assertRaisesRegex(TypeError, "Failed to convert elements"):
      tensor_util.make_tensor_proto([tensor_shape.Dimension(1)])

  def testTensorShapeVerification(self):
    array = np.array([[1], [2]])
    correct_shape = (2, 1)
    incorrect_shape = (1, 2)
    # With verify_shape=True a matching shape is accepted and a mismatching one
    # raises TypeError.
    tensor_util.make_tensor_proto(array, shape=correct_shape, verify_shape=True)
    with self.assertRaises(TypeError):
      tensor_util.make_tensor_proto(
          array, shape=incorrect_shape, verify_shape=True)

  def testShapeTooLarge(self):
    with self.assertRaises(ValueError):
      tensor_util.make_tensor_proto(np.array([1, 2]), shape=[1])

  def testLowRankSupported(self):
    t = tensor_util.make_tensor_proto(np.array(7))
    self.assertProtoEquals("""
      dtype: DT_INT64
      tensor_shape {}
      int64_val: 7
      """, t)

  def testShapeEquals(self):
    t = tensor_util.make_tensor_proto([10, 20, 30, 40], shape=[2, 2])
    # ShapeEquals is exercised with a list, a tuple and a TensorShapeProto.
    self.assertTrue(tensor_util.ShapeEquals(t, [2, 2]))
    self.assertTrue(tensor_util.ShapeEquals(t, (2, 2)))
    self.assertTrue(
        tensor_util.ShapeEquals(t, tensor_shape.as_shape([2, 2]).as_proto()))
    self.assertFalse(tensor_util.ShapeEquals(t, [5, 3]))
    self.assertFalse(tensor_util.ShapeEquals(t, [1, 4]))
    self.assertFalse(tensor_util.ShapeEquals(t, [4]))


@test_util.run_all_in_graph_and_eager_modes
class IsTensorTest(test.TestCase):
  """Tests for tensor_util.is_tf_type.

  Each test checks that a TensorFlow object is reported as a TF type while its
  NumPy / evaluated-value counterpart is not.
  """

  def testConstantTensor(self):
    np_val = np.random.rand(3).astype(np.int32)
    tf_val = constant_op.constant(np_val)
    self.assertFalse(tensor_util.is_tf_type(np_val))
    self.assertTrue(tensor_util.is_tf_type(tf_val))

  def testRaggedTensor(self):
    rt = ragged_factory_ops.constant([[1, 2], [3]])
    rt_value = self.evaluate(rt)
    self.assertTrue(tensor_util.is_tf_type(rt))
    self.assertFalse(tensor_util.is_tf_type(rt_value))

  def testSparseTensor(self):
    st = sparse_tensor.SparseTensor([[1, 2]], [3], [10, 10])
    st_value = self.evaluate(st)
    self.assertTrue(tensor_util.is_tf_type(st))
    self.assertFalse(tensor_util.is_tf_type(st_value))

  def testIndexedSlices(self):
    x = indexed_slices.IndexedSlices(
        constant_op.constant([1, 2, 3]), constant_op.constant([10, 20, 30]))
    x_value = indexed_slices.IndexedSlicesValue(
        np.array([1, 2, 3]), np.array([10, 20, 30]), np.array([100]))
    self.assertTrue(tensor_util.is_tf_type(x))
    self.assertFalse(tensor_util.is_tf_type(x_value))

  def testVariable(self):
    v = variables.Variable([1, 2, 3])
    self.assertTrue(tensor_util.is_tf_type(v))


class ConstantValueTest(test.TestCase):
  """Tests for tensor_util.constant_value on various ops."""

  def testConstant(self):
    np_val = np.random.rand(3, 4, 7).astype(np.float32)
    tf_val = constant_op.constant(np_val)
    self.assertAllClose(np_val, tensor_util.constant_value(tf_val))

    # Same check with a zero-sized middle dimension.
    np_val = np.random.rand(3, 0, 7).astype(np.float32)
    tf_val = constant_op.constant(np_val)
    self.assertAllClose(np_val, tensor_util.constant_value(tf_val))

  def testUnknown(self):
    with ops.Graph().as_default():
      tf_val = gen_state_ops.variable(
          shape=[3, 4, 7],
          dtype=dtypes.float32,
          name="tf_val",
          container="",
          shared_name="")
      self.assertIs(None, tensor_util.constant_value(tf_val))

  def testShape(self):
    np_val = np.array([1, 2, 3], dtype=np.int32)
    tf_val = array_ops.shape(constant_op.constant(0.0, shape=[1, 2, 3]))
    c_val = tensor_util.constant_value(tf_val)
    self.assertAllEqual(np_val, c_val)
    self.assertEqual(np.int32, c_val.dtype)

  def testFill(self):
    np_val = np.array([-1, -1, -1], dtype=np.float32)
    tf_val = array_ops.fill([3], constant_op.constant(-1.0))
    c_val = tensor_util.constant_value(tf_val)
    self.assertAllEqual(np_val, c_val)
    self.assertEqual(np.float32, c_val.dtype)

  def testSize(self):
    tf_val = array_ops.size(constant_op.constant(0.0, shape=[1, 2, 3]))
    c_val = tensor_util.constant_value(tf_val)
    self.assertEqual(6, c_val)
873 874 def testSizeOfScalar(self): 875 tf_val = array_ops.size(constant_op.constant(0.0)) 876 c_val = tensor_util.constant_value(tf_val) 877 self.assertEqual(1, c_val) 878 self.assertIn(type(c_val), [np.ndarray, np.int32]) 879 880 def testRank(self): 881 tf_val = array_ops.rank(constant_op.constant(0.0, shape=[1, 2, 3])) 882 c_val = tensor_util.constant_value(tf_val) 883 884 self.assertIn(type(c_val), [np.ndarray, np.int32]) 885 self.assertEqual((), c_val.shape) 886 self.assertEqual(3, c_val) 887 888 # Repeat test using array_ops.rank_internal to avoid the optimization that 889 # happens in the rank function. 890 tf_val = array_ops.rank_internal( 891 constant_op.constant( 892 0.0, shape=[1, 2, 3]), optimize=False) 893 c_val = tensor_util.constant_value(tf_val) 894 895 self.assertIn(type(c_val), [np.ndarray, np.int32]) 896 self.assertEqual((), c_val.shape) 897 self.assertEqual(3, c_val) 898 self.assertEqual([3], c_val) 899 900 def testCast(self): 901 np_val = np.random.rand(3, 4, 7).astype(np.float32) 902 tf_val = math_ops.cast(constant_op.constant(np_val), dtypes.float64) 903 c_val = tensor_util.constant_value(tf_val) 904 self.assertAllClose(np_val.astype(np.float64), c_val) 905 906 np_val = np.random.rand(3, 0, 7).astype(np.float32) 907 tf_val = math_ops.cast(constant_op.constant(np_val), dtypes.float64) 908 c_val = tensor_util.constant_value(tf_val) 909 self.assertAllClose(np_val.astype(np.float64), c_val) 910 911 def testConcat(self): 912 np_val = np.random.rand(3, 4, 7).astype(np.float32) 913 tf_val = array_ops.concat( 914 [np_val[0:1, :, :], np_val[1:2, :, :], np_val[2:3, :, :]], 0) 915 c_val = tensor_util.constant_value(tf_val) 916 self.assertAllClose(np_val, c_val) 917 918 # This test needs a placeholder which means we need to construct a graph. 
919 with ops.Graph().as_default(): 920 tf_val = array_ops.concat( 921 [np_val[0, :, :], np_val[1, :, :], np_val[2, :, :]], 922 array_ops.placeholder(dtypes.int32)) 923 c_val = tensor_util.constant_value(tf_val) 924 self.assertIs(None, c_val) 925 926 tf_val = array_ops.concat([ 927 np_val[0, :, :], 928 array_ops.placeholder(dtypes.float32), np_val[2, :, :] 929 ], 1) 930 c_val = tensor_util.constant_value(tf_val) 931 self.assertIs(None, c_val) 932 933 def testPack_Axis0(self): 934 inputs = [np.random.rand(4, 7) for _ in range(3)] 935 np_val = np.array(inputs) 936 tf_val = array_ops.stack(inputs) 937 c_val = tensor_util.constant_value(tf_val) 938 self.assertAllClose(np_val, c_val) 939 940 # This test needs a placeholder which means we need to construct a graph. 941 with ops.Graph().as_default(): 942 tf_val = array_ops.stack( 943 [inputs[0], 944 array_ops.placeholder(dtypes.float32), inputs[2]]) 945 c_val = tensor_util.constant_value(tf_val) 946 self.assertIs(None, c_val) 947 948 def testPack_Axis1(self): 949 # This test needs a placeholder which means we need to construct a graph. 950 with ops.Graph().as_default(): 951 inputs = [np.random.rand(4, 7) for _ in range(3)] 952 tf_val = array_ops.stack(inputs, axis=1) 953 c_val = tensor_util.constant_value(tf_val) 954 self.assertIsNone(c_val) 955 956 tf_val = array_ops.stack( 957 [inputs[0], 958 array_ops.placeholder(dtypes.float32), inputs[2]], axis=1) 959 c_val = tensor_util.constant_value(tf_val) 960 self.assertIs(None, c_val) 961 962 def testPack_Partial_Axis0(self): 963 input_ = np.random.rand(4, 7) 964 # This test needs a placeholder which means we need to construct a graph. 
965 with ops.Graph().as_default(): 966 tf_val = array_ops.stack([input_, array_ops.placeholder(dtypes.float32)]) 967 c_val = tensor_util.constant_value(tf_val, partial=True) 968 self.assertAllClose(input_, c_val[0]) 969 self.assertIsNone(c_val[1]) 970 971 def testPack_Partial_Axis1(self): 972 input_ = np.random.rand(4, 7) 973 # This test needs a placeholder which means we need to construct a graph. 974 with ops.Graph().as_default(): 975 tf_val = array_ops.stack( 976 [input_, array_ops.placeholder(dtypes.float32)], axis=1) 977 c_val = tensor_util.constant_value(tf_val, partial=True) 978 self.assertIsNone(c_val) 979 980 def testUnpack_Axis0(self): 981 inputs = np.random.rand(3, 4, 7) 982 tf_vals = array_ops.unstack(inputs) 983 c_vals = [tensor_util.constant_value(x) for x in tf_vals] 984 self.assertAllClose(inputs, c_vals) 985 986 def testUnpack_Partial_Axis0(self): 987 input_ = np.random.rand(4, 7) 988 # This test needs a placeholder which means we need to construct a graph. 989 with ops.Graph().as_default(): 990 packed = array_ops.stack([input_, array_ops.placeholder(dtypes.float32)]) 991 tf_vals = array_ops.unstack(packed) 992 c_vals = [tensor_util.constant_value(x, partial=True) for x in tf_vals] 993 self.assertAllClose(input_, c_vals[0]) 994 self.assertIsNone(c_vals[1]) 995 996 def testSplit_Axis0(self): 997 inputs = np.random.rand(6, 5, 7) 998 tf_vals = array_ops.split(inputs, 3) 999 c_vals = [tensor_util.constant_value(x) for x in tf_vals] 1000 self.assertAllClose(np.split(inputs, 3), c_vals) 1001 1002 def testSplit_Partial_Axis0(self): 1003 input_ = np.random.rand(4, 7) 1004 # This test needs a placeholder which means we need to construct a graph. 
1005 with ops.Graph().as_default(): 1006 placeholder = array_ops.placeholder(dtypes.float32, shape=(4, 7)) 1007 # it'd be better to use concat here, but concat doesn't support partial 1008 packed = array_ops.stack([input_, placeholder]) 1009 tf_vals = array_ops.split(packed, 2) 1010 c_vals = [tensor_util.constant_value(x, partial=True) for x in tf_vals] 1011 self.assertAllClose(input_, c_vals[0][0]) 1012 self.assertIsNone(c_vals[1][0]) 1013 1014 def testEqual(self): 1015 # Scalar inputs. 1016 tf_val = math_ops.equal(constant_op.constant(1), constant_op.constant(1)) 1017 self.assertEqual(tensor_util.constant_value(tf_val), True) 1018 1019 tf_val = math_ops.equal(constant_op.constant(1), constant_op.constant(0)) 1020 self.assertEqual(tensor_util.constant_value(tf_val), False) 1021 1022 # Shaped inputs with broadcast semantics. 1023 tf_val = math_ops.equal(constant_op.constant([[0, 1]]), 1024 constant_op.constant([[0], [1]])) 1025 c_val = tensor_util.constant_value(tf_val) 1026 self.assertAllEqual(c_val, [[True, False], [False, True]]) 1027 1028 def testNotEqual(self): 1029 # Scalar inputs. 1030 tf_val = math_ops.not_equal(constant_op.constant(1), 1031 constant_op.constant(1)) 1032 self.assertEqual(tensor_util.constant_value(tf_val), False) 1033 1034 tf_val = math_ops.not_equal(constant_op.constant(1), 1035 constant_op.constant(0)) 1036 self.assertEqual(tensor_util.constant_value(tf_val), True) 1037 1038 # Shaped inputs with broadcast semantics. 
1039 tf_val = math_ops.not_equal(constant_op.constant([[0, 1]]), 1040 constant_op.constant([[0], [1]])) 1041 c_val = tensor_util.constant_value(tf_val) 1042 self.assertAllEqual(c_val, [[False, True], [True, False]]) 1043 1044 def testStopGradient(self): 1045 input_ = np.random.rand(4, 7) 1046 tf_val = array_ops.stop_gradient(input_) 1047 c_val = tensor_util.constant_value(tf_val) 1048 self.assertAllEqual(input_, c_val) 1049 1050 def testIdentity(self): 1051 input_ = np.random.rand(4, 7) 1052 tf_val = array_ops.identity(input_) 1053 c_val = tensor_util.constant_value(tf_val) 1054 self.assertAllEqual(input_, c_val) 1055 1056 def testLiteral(self): 1057 x = "hi" 1058 self.assertIs(x, tensor_util.constant_value(x)) 1059 1060 def testNumpyNdarray(self): 1061 np_val = np.random.rand(3, 4, 7).astype(np.float32) 1062 self.assertIs(np_val, tensor_util.constant_value(np_val)) 1063 1064 def testVariable(self): 1065 var = variables.Variable(1.0, name="variable_node") 1066 self.assertIsNone(tensor_util.constant_value(var)) 1067 1068 def testVariableV1(self): 1069 var = variables.VariableV1(1.0, name="variable_node") 1070 self.assertIsNone(tensor_util.constant_value(var)) 1071 1072 1073class ConstantValueAsShapeTest(test.TestCase): 1074 1075 @test_util.run_in_graph_and_eager_modes 1076 def testConstant(self): 1077 np_val = np.random.rand(3).astype(np.int32) 1078 tf_val = constant_op.constant(np_val) 1079 self.assertEqual( 1080 tensor_shape.TensorShape(np_val), 1081 tensor_util.constant_value_as_shape(tf_val)) 1082 1083 tf_val = constant_op.constant([], dtype=dtypes.int32) 1084 self.assertEqual( 1085 tensor_shape.TensorShape([]), 1086 tensor_util.constant_value_as_shape(tf_val)) 1087 1088 @test_util.run_in_graph_and_eager_modes 1089 def testCast(self): 1090 tf_val = math_ops.cast( 1091 array_ops.shape(constant_op.constant(0.0, shape=[1, 2, 3])), 1092 dtypes.int64) 1093 c_val = tensor_util.constant_value_as_shape(tf_val) 1094 self.assertEqual(tensor_shape.TensorShape([1, 2, 3]), 
c_val) 1095 1096 @test_util.run_in_graph_and_eager_modes 1097 def testCastWithUnknown(self): 1098 tf_val = math_ops.cast(constant_op.constant([-1, 1, -1]), dtypes.int64) 1099 c_val = tensor_util.constant_value_as_shape(tf_val) 1100 self.assertEqual([None, 1, None], c_val.as_list()) 1101 1102 @test_util.run_in_graph_and_eager_modes 1103 def testShape(self): 1104 tf_val = array_ops.shape(constant_op.constant(0.0, shape=[1, 2, 3])) 1105 c_val = tensor_util.constant_value_as_shape(tf_val) 1106 self.assertEqual(tensor_shape.TensorShape([1, 2, 3]), c_val) 1107 1108 @test_util.run_in_graph_and_eager_modes 1109 def testMinusOneBecomesNone(self): 1110 tf_val = constant_op.constant([-1, 1, -1], shape=[3]) 1111 c_val = tensor_util.constant_value_as_shape(tf_val) 1112 self.assertEqual([None, 1, None], c_val.as_list()) 1113 1114 def testPack(self): 1115 # This test needs a placeholder which means we need to construct a graph. 1116 with ops.Graph().as_default(): 1117 tf_val = array_ops.stack( 1118 [constant_op.constant(16), 37, 1119 array_ops.placeholder(dtypes.int32)]) 1120 c_val = tensor_util.constant_value_as_shape(tf_val) 1121 self.assertEqual([16, 37, None], c_val.as_list()) 1122 1123 def testConcat(self): 1124 # This test needs a placeholder which means we need to construct a graph. 1125 with ops.Graph().as_default(): 1126 tf_val = array_ops.concat( 1127 [[16, 37], array_ops.placeholder(dtypes.int32, shape=(2,))], 0) 1128 c_val = tensor_util.constant_value_as_shape(tf_val) 1129 self.assertEqual([16, 37, None, None], c_val.as_list()) 1130 1131 tf_val = array_ops.concat( 1132 [[16, 37], 1133 array_ops.placeholder(dtypes.int32, shape=(1,)), [48]], 0) 1134 c_val = tensor_util.constant_value_as_shape(tf_val) 1135 self.assertEqual([16, 37, None, 48], c_val.as_list()) 1136 1137 def testSlice(self): 1138 # This test needs a placeholder which means we need to construct a graph. 
1139 with ops.Graph().as_default(): 1140 tf_val = array_ops.placeholder(dtypes.int32, shape=(4,))[0:2] 1141 c_val = tensor_util.constant_value_as_shape(tf_val) 1142 self.assertEqual([None, None], c_val.as_list()) 1143 1144 # begin:end 1145 tf_val = constant_op.constant([10, 20, 30])[1:3] 1146 c_val = tensor_util.constant_value_as_shape(tf_val) 1147 self.assertEqual([20, 30], c_val.as_list()) 1148 1149 # begin:end:stride 1150 tf_val = array_ops.strided_slice( 1151 constant_op.constant([10, 20, 30]), [1], [3], strides=[2]) 1152 c_val = tensor_util.constant_value_as_shape(tf_val) 1153 self.assertEqual([20], c_val.as_list()) 1154 1155 # [1, 2, 16, 37, None, 48] 1156 # This test needs a placeholder which means we need to construct a graph. 1157 with ops.Graph().as_default(): 1158 tf_val_orig = array_ops.concat( 1159 [[1, 2, 16, 37], 1160 array_ops.placeholder(dtypes.int32, shape=(1,)), [48]], 0) 1161 1162 # begin: no end 1163 tf_val = tf_val_orig[2:] 1164 c_val = tensor_util.constant_value_as_shape(tf_val) 1165 self.assertEqual([16, 37, None, 48], c_val.as_list()) 1166 1167 # begin::negative slice 1168 tf_val = tf_val_orig[2::-1] 1169 c_val = tensor_util.constant_value_as_shape(tf_val) 1170 self.assertEqual([16, 2, 1], c_val.as_list()) 1171 1172 # :end:negative slice 1173 tf_val = tf_val_orig[:1:-2] 1174 c_val = tensor_util.constant_value_as_shape(tf_val) 1175 self.assertEqual([48, 37], c_val.as_list()) 1176 1177 # begin:end:negative slice 1178 tf_val = tf_val_orig[3:1:-1] 1179 c_val = tensor_util.constant_value_as_shape(tf_val) 1180 self.assertEqual([37, 16], c_val.as_list()) 1181 1182 # begin:negative end:slice 1183 tf_val = tf_val_orig[1:-3:1] 1184 c_val = tensor_util.constant_value_as_shape(tf_val) 1185 self.assertEqual([2, 16], c_val.as_list()) 1186 1187 # negative begin::slice 1188 tf_val = tf_val_orig[-3::1] 1189 c_val = tensor_util.constant_value_as_shape(tf_val) 1190 self.assertEqual([37, None, 48], c_val.as_list()) 1191 1192 # negative begin::negative slice 
1193 tf_val = tf_val_orig[-3::-1] 1194 c_val = tensor_util.constant_value_as_shape(tf_val) 1195 self.assertEqual([37, 16, 2, 1], c_val.as_list()) 1196 1197 # negative begin:negative end:negative slice 1198 tf_val = tf_val_orig[-3:-5:-1] 1199 c_val = tensor_util.constant_value_as_shape(tf_val) 1200 self.assertEqual([37, 16], c_val.as_list()) 1201 1202 # Do not support shape inference for additional arguments 1203 tf_val = constant_op.constant([10, 20, 30])[...] 1204 c_val = tensor_util.constant_value_as_shape(tf_val) 1205 self.assertEqual([None, None, None], c_val.as_list()) 1206 1207 # Do not support shape inference for tensor slices. 1208 tf_val = constant_op.constant( 1209 [10, 20, 30])[array_ops.placeholder(dtypes.int32, shape=()):] 1210 c_val = tensor_util.constant_value_as_shape(tf_val) 1211 self.assertEqual(tensor_shape.unknown_shape(), c_val) 1212 1213 # Do not support shape inference for higher rank 1214 with self.assertRaises(ValueError): 1215 tf_val = constant_op.constant([[10], [20], [30]])[:, 0:] 1216 c_val = tensor_util.constant_value_as_shape(tf_val) 1217 1218 1219class MaybeSetStaticShapeTest(test.TestCase): 1220 1221 @contextlib.contextmanager 1222 def disableSetStaticShape(self): 1223 flag_old = tensor_util._ENABLE_MAYBE_SET_STATIC_SHAPE 1224 tensor_util._ENABLE_MAYBE_SET_STATIC_SHAPE = False 1225 try: 1226 yield 1227 finally: 1228 tensor_util._ENABLE_MAYBE_SET_STATIC_SHAPE = flag_old 1229 1230 def testMaybeSetStaticShape(self): 1231 shape = constant_op.constant([2, 5], dtype=dtypes.int32) 1232 1233 def reshape(): 1234 v = array_ops.zeros([10]) 1235 return array_ops.reshape(v, shape) 1236 # This test needs a placeholder which means we need to construct a graph. 
1237 with ops.Graph().as_default(): 1238 with self.disableSetStaticShape(): 1239 graph_without_shape_propagation = func_graph.func_graph_from_py_func( 1240 "without_shape_propagation", reshape, [], {}) 1241 graph_with_shape_propagation = func_graph.func_graph_from_py_func( 1242 "with_shape_propagation", reshape, [], {}) 1243 self.assertCountEqual( 1244 [op.type for op in graph_without_shape_propagation.get_operations()], 1245 [op.type for op in graph_with_shape_propagation.get_operations()]) 1246 1247 def testMaybeSetStaticShapeScalarShape(self): 1248 1249 def reshape(): 1250 v = array_ops.placeholder(dtypes.float32) 1251 t = array_ops.reshape(v, [-1]) 1252 return t 1253 1254 with self.disableSetStaticShape(): 1255 graph_without_shape_propagation = func_graph.func_graph_from_py_func( 1256 "without_shape_propagation", reshape, [], {}) 1257 graph_with_shape_propagation = func_graph.func_graph_from_py_func( 1258 "with_shape_propagation", reshape, [], {}) 1259 self.assertCountEqual( 1260 [op.type for op in graph_without_shape_propagation.get_operations()], 1261 [op.type for op in graph_with_shape_propagation.get_operations()]) 1262 1263 1264class ShapeTensorTest(test_util.TensorFlowTestCase): 1265 1266 @test_util.run_in_graph_and_eager_modes 1267 def testConversion(self): 1268 """Make sure fully known TensorShape objects convert to Tensors.""" 1269 shape = tensor_shape.TensorShape([1, tensor_shape.Dimension(2)]) 1270 shape_tensor = tensor_util.shape_tensor(shape) 1271 self.assertAllEqual((1, 2), shape_tensor) 1272 1273 1274if __name__ == "__main__": 1275 test.main() 1276