1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities to create TensorProtos.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21import six 22 23from tensorflow.core.framework import tensor_pb2 24from tensorflow.core.framework import tensor_shape_pb2 25from tensorflow.python.eager import context 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import errors_impl 28from tensorflow.python.framework import ops 29from tensorflow.python.framework import tensor_shape 30from tensorflow.python.types import core 31from tensorflow.python.types import internal 32from tensorflow.python.util import compat 33from tensorflow.python.util import nest 34from tensorflow.python.util.tf_export import tf_export 35 36# Fallback in case fast_tensor_util is not properly compiled. 37# pylint: disable=g-import-not-at-top 38try: 39 from tensorflow.python.framework import fast_tensor_util 40 _FAST_TENSOR_UTIL_AVAILABLE = True 41except ImportError: 42 _FAST_TENSOR_UTIL_AVAILABLE = False 43# pylint: enable=g-import-not-at-top 44 45 46def ExtractBitsFromFloat16(x): 47 return np.asarray(x, dtype=np.float16).view(np.uint16).item() 48 49 50def SlowAppendFloat16ArrayToTensorProto(tensor_proto, proto_values): 51 tensor_proto.half_val.extend( 52 [ExtractBitsFromFloat16(x) for x in proto_values]) 53 54 55def _MediumAppendFloat16ArrayToTensorProto(tensor_proto, proto_values): 56 # TODO: Remove the conversion if cython supports np.float16_t 57 fast_tensor_util.AppendFloat16ArrayToTensorProto( 58 tensor_proto, 59 np.asarray(proto_values, dtype=np.float16).view(np.uint16)) 60 61 62def ExtractBitsFromBFloat16(x): 63 return np.asarray( 64 x, dtype=dtypes.bfloat16.as_numpy_dtype).view(np.uint16).item() 65 66 67def SlowAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values): 68 tensor_proto.half_val.extend( 69 [ExtractBitsFromBFloat16(x) for x in proto_values]) 70 71 72def FastAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values): 73 fast_tensor_util.AppendBFloat16ArrayToTensorProto( 74 tensor_proto, np.asarray( 75 proto_values, dtype=dtypes.bfloat16.as_numpy_dtype).view(np.uint16)) 76 77 78if _FAST_TENSOR_UTIL_AVAILABLE: 79 _NP_TO_APPEND_FN = { 80 dtypes.bfloat16.as_numpy_dtype: 81 FastAppendBFloat16ArrayToTensorProto, 82 np.float16: 83 _MediumAppendFloat16ArrayToTensorProto, 84 np.float32: 85 fast_tensor_util.AppendFloat32ArrayToTensorProto, 86 np.float64: 87 fast_tensor_util.AppendFloat64ArrayToTensorProto, 88 np.int32: 89 fast_tensor_util.AppendInt32ArrayToTensorProto, 90 np.int64: 91 fast_tensor_util.AppendInt64ArrayToTensorProto, 92 np.uint8: 93 fast_tensor_util.AppendUInt8ArrayToTensorProto, 94 np.uint16: 95 fast_tensor_util.AppendUInt16ArrayToTensorProto, 96 np.uint32: 97 fast_tensor_util.AppendUInt32ArrayToTensorProto, 98 np.uint64: 99 fast_tensor_util.AppendUInt64ArrayToTensorProto, 100 np.int8: 101 fast_tensor_util.AppendInt8ArrayToTensorProto, 102 np.int16: 103 fast_tensor_util.AppendInt16ArrayToTensorProto, 104 np.complex64: 105 fast_tensor_util.AppendComplex64ArrayToTensorProto, 106 np.complex128: 107 fast_tensor_util.AppendComplex128ArrayToTensorProto, 108 np.object: 109 fast_tensor_util.AppendObjectArrayToTensorProto, 110 np.bool: 111 fast_tensor_util.AppendBoolArrayToTensorProto, 112 dtypes.qint8.as_numpy_dtype: 113 fast_tensor_util.AppendInt8ArrayToTensorProto, 114 dtypes.quint8.as_numpy_dtype: 115 fast_tensor_util.AppendUInt8ArrayToTensorProto, 116 dtypes.qint16.as_numpy_dtype: 117 fast_tensor_util.AppendInt16ArrayToTensorProto, 118 dtypes.quint16.as_numpy_dtype: 119 fast_tensor_util.AppendUInt16ArrayToTensorProto, 120 dtypes.qint32.as_numpy_dtype: 121 fast_tensor_util.AppendInt32ArrayToTensorProto, 122 # NOTE(touts): Intentionally no way to feed a DT_BFLOAT16. 123 } 124else: 125 126 def SlowAppendFloat32ArrayToTensorProto(tensor_proto, proto_values): 127 tensor_proto.float_val.extend([x.item() for x in proto_values]) 128 129 def SlowAppendFloat64ArrayToTensorProto(tensor_proto, proto_values): 130 tensor_proto.double_val.extend([x.item() for x in proto_values]) 131 132 def SlowAppendIntArrayToTensorProto(tensor_proto, proto_values): 133 tensor_proto.int_val.extend([x.item() for x in proto_values]) 134 135 def SlowAppendInt64ArrayToTensorProto(tensor_proto, proto_values): 136 tensor_proto.int64_val.extend([x.item() for x in proto_values]) 137 138 def SlowAppendQIntArrayToTensorProto(tensor_proto, proto_values): 139 tensor_proto.int_val.extend([x.item()[0] for x in proto_values]) 140 141 def SlowAppendUInt32ArrayToTensorProto(tensor_proto, proto_values): 142 tensor_proto.uint32_val.extend([x.item() for x in proto_values]) 143 144 def SlowAppendUInt64ArrayToTensorProto(tensor_proto, proto_values): 145 tensor_proto.uint64_val.extend([x.item() for x in proto_values]) 146 147 def SlowAppendComplex64ArrayToTensorProto(tensor_proto, proto_values): 148 tensor_proto.scomplex_val.extend( 149 [v.item() for x in proto_values for v in [x.real, x.imag]]) 150 151 def SlowAppendComplex128ArrayToTensorProto(tensor_proto, proto_values): 152 tensor_proto.dcomplex_val.extend( 153 [v.item() for x in proto_values for v in [x.real, x.imag]]) 154 155 def SlowAppendObjectArrayToTensorProto(tensor_proto, proto_values): 156 tensor_proto.string_val.extend([compat.as_bytes(x) for x in proto_values]) 157 158 def SlowAppendBoolArrayToTensorProto(tensor_proto, proto_values): 159 tensor_proto.bool_val.extend([x.item() for x in proto_values]) 160 161 _NP_TO_APPEND_FN = { 162 dtypes.bfloat16.as_numpy_dtype: SlowAppendBFloat16ArrayToTensorProto, 163 np.float16: SlowAppendFloat16ArrayToTensorProto, 164 np.float32: SlowAppendFloat32ArrayToTensorProto, 165 np.float64: SlowAppendFloat64ArrayToTensorProto, 166 np.int32: SlowAppendIntArrayToTensorProto, 167 np.int64: SlowAppendInt64ArrayToTensorProto, 168 np.uint8: SlowAppendIntArrayToTensorProto, 169 np.uint16: SlowAppendIntArrayToTensorProto, 170 np.uint32: SlowAppendUInt32ArrayToTensorProto, 171 np.uint64: SlowAppendUInt64ArrayToTensorProto, 172 np.int8: SlowAppendIntArrayToTensorProto, 173 np.int16: SlowAppendIntArrayToTensorProto, 174 np.complex64: SlowAppendComplex64ArrayToTensorProto, 175 np.complex128: SlowAppendComplex128ArrayToTensorProto, 176 np.object: SlowAppendObjectArrayToTensorProto, 177 np.bool: SlowAppendBoolArrayToTensorProto, 178 dtypes.qint8.as_numpy_dtype: SlowAppendQIntArrayToTensorProto, 179 dtypes.quint8.as_numpy_dtype: SlowAppendQIntArrayToTensorProto, 180 dtypes.qint16.as_numpy_dtype: SlowAppendQIntArrayToTensorProto, 181 dtypes.quint16.as_numpy_dtype: SlowAppendQIntArrayToTensorProto, 182 dtypes.qint32.as_numpy_dtype: SlowAppendQIntArrayToTensorProto, 183 # NOTE(touts): Intentionally no way to feed a DT_BFLOAT16. 184 } 185 186 187def GetFromNumpyDTypeDict(dtype_dict, dtype): 188 # NOTE: dtype_dict.get(dtype) always returns None. 189 for key, val in six.iteritems(dtype_dict): 190 if key == dtype: 191 return val 192 return None 193 194 195def GetNumpyAppendFn(dtype): 196 # numpy dtype for strings are variable length. We can not compare 197 # dtype with a single constant (np.string does not exist) to decide 198 # dtype is a "string" type. We need to compare the dtype.type to be 199 # sure it's a string type. 200 if dtype.type == np.string_ or dtype.type == np.unicode_: 201 if _FAST_TENSOR_UTIL_AVAILABLE: 202 return fast_tensor_util.AppendObjectArrayToTensorProto 203 else: 204 return SlowAppendObjectArrayToTensorProto 205 return GetFromNumpyDTypeDict(_NP_TO_APPEND_FN, dtype) 206 207 208def TensorShapeProtoToList(shape): 209 """Convert a TensorShape to a list. 210 211 Args: 212 shape: A TensorShapeProto. 213 214 Returns: 215 List of integers representing the dimensions of the tensor. 216 """ 217 return [dim.size for dim in shape.dim] 218 219 220def _GetDenseDimensions(list_of_lists): 221 """Returns the inferred dense dimensions of a list of lists.""" 222 if not isinstance(list_of_lists, (list, tuple)): 223 return [] 224 elif not list_of_lists: 225 return [0] 226 else: 227 return [len(list_of_lists)] + _GetDenseDimensions(list_of_lists[0]) 228 229 230def _FlattenToStrings(nested_strings): 231 if isinstance(nested_strings, (list, tuple)): 232 for inner in nested_strings: 233 for flattened_string in _FlattenToStrings(inner): 234 yield flattened_string 235 else: 236 yield nested_strings 237 238 239_TENSOR_CONTENT_TYPES = frozenset([ 240 dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32, dtypes.uint8, 241 dtypes.int16, dtypes.int8, dtypes.int64, dtypes.qint8, dtypes.quint8, 242 dtypes.qint16, dtypes.quint16, dtypes.qint32, dtypes.uint32, dtypes.uint64 243]) 244 245 246# pylint: disable=invalid-name 247def _check_failed(v): 248 # NB. none of the _check_* functions could raise a ValueError, so 249 # it is safe to use here. 250 raise ValueError(v) 251 252 253def _check_quantized(values): 254 # Cannot rely on `nest` because the leaves are tuples. 255 if not isinstance(values, (list, tuple)): 256 _check_failed(values) 257 if isinstance(values, tuple): 258 _ = [_check_int(v) for v in values] 259 else: 260 _ = [_check_quantized(v) for v in values] 261 262 263def _generate_isinstance_check(expected_types): 264 def inner(values): 265 for v in nest.flatten(values): 266 if not (isinstance(v, expected_types) or 267 (isinstance(v, np.ndarray) and 268 issubclass(v.dtype.type, expected_types))): 269 _check_failed(v) 270 271 return inner 272 273_check_int = _generate_isinstance_check( 274 (compat.integral_types, tensor_shape.Dimension)) 275_check_float = _generate_isinstance_check(compat.real_types) 276_check_complex = _generate_isinstance_check(compat.complex_types) 277_check_str = _generate_isinstance_check(compat.bytes_or_text_types) 278_check_bool = _generate_isinstance_check(bool) 279 280 281def _check_not_tensor(values): 282 _ = [_check_failed(v) for v in nest.flatten(values) 283 if isinstance(v, ops.Tensor)] 284# pylint: enable=invalid-name 285 286_TF_TO_IS_OK = { 287 dtypes.bool: _check_bool, 288 dtypes.complex128: _check_complex, 289 dtypes.complex64: _check_complex, 290 dtypes.float16: _check_float, 291 dtypes.float32: _check_float, 292 dtypes.float64: _check_float, 293 dtypes.int16: _check_int, 294 dtypes.int32: _check_int, 295 dtypes.int64: _check_int, 296 dtypes.int8: _check_int, 297 dtypes.qint16: _check_quantized, 298 dtypes.qint32: _check_quantized, 299 dtypes.qint8: _check_quantized, 300 dtypes.quint16: _check_quantized, 301 dtypes.quint8: _check_quantized, 302 dtypes.string: _check_str, 303 dtypes.uint16: _check_int, 304 dtypes.uint8: _check_int, 305 dtypes.uint32: _check_int, 306 dtypes.uint64: _check_int, 307} 308 309 310def _AssertCompatible(values, dtype): 311 if dtype is None: 312 fn = _check_not_tensor 313 else: 314 try: 315 fn = _TF_TO_IS_OK[dtype] 316 except KeyError: 317 # There isn't a specific fn, so we try to do the best possible. 318 if dtype.is_integer: 319 fn = _check_int 320 elif dtype.is_floating: 321 fn = _check_float 322 elif dtype.is_complex: 323 fn = _check_complex 324 elif dtype.is_quantized: 325 fn = _check_quantized 326 else: 327 fn = _check_not_tensor 328 329 try: 330 fn(values) 331 except ValueError as e: 332 [mismatch] = e.args 333 if dtype is None: 334 raise TypeError("Expected any non-tensor type, got a tensor instead.") 335 else: 336 raise TypeError("Expected %s, got %s of type '%s' instead." % 337 (dtype.name, repr(mismatch), type(mismatch).__name__)) 338 339 340def _is_array_like(obj): # pylint: disable=invalid-name 341 """Check if a given object is array-like.""" 342 if isinstance(obj, ops.Tensor) and not isinstance(obj, ops._EagerTensorBase): # pylint: disable=protected-access 343 # Tensor implements __array__ only so it can inform the user that it is not 344 # a valid array. 345 return False 346 347 # TODO(slebedev): an object could also implement C-level array interface. 348 if (callable(getattr(obj, "__array__", None)) or 349 isinstance(getattr(obj, "__array_interface__", None), dict)): 350 return True 351 352 try: 353 memoryview(obj) 354 except TypeError: 355 return False 356 else: 357 return not isinstance(obj, bytes) 358 359 360# pylint: disable=invalid-name 361@tf_export("make_tensor_proto") 362def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False, 363 allow_broadcast=False): 364 """Create a TensorProto. 365 366 In TensorFlow 2.0, representing tensors as protos should no longer be a 367 common workflow. That said, this utility function is still useful for 368 generating TF Serving request protos: 369 370 ```python 371 request = tensorflow_serving.apis.predict_pb2.PredictRequest() 372 request.model_spec.name = "my_model" 373 request.model_spec.signature_name = "serving_default" 374 request.inputs["images"].CopyFrom(tf.make_tensor_proto(X_new)) 375 ``` 376 377 `make_tensor_proto` accepts "values" of a python scalar, a python list, a 378 numpy ndarray, or a numpy scalar. 379 380 If "values" is a python scalar or a python list, make_tensor_proto 381 first convert it to numpy ndarray. If dtype is None, the 382 conversion tries its best to infer the right numpy data 383 type. Otherwise, the resulting numpy array has a compatible data 384 type with the given dtype. 385 386 In either case above, the numpy ndarray (either the caller provided 387 or the auto-converted) must have the compatible type with dtype. 388 389 `make_tensor_proto` then converts the numpy array to a tensor proto. 390 391 If "shape" is None, the resulting tensor proto represents the numpy 392 array precisely. 393 394 Otherwise, "shape" specifies the tensor's shape and the numpy array 395 can not have more elements than what "shape" specifies. 396 397 Args: 398 values: Values to put in the TensorProto. 399 dtype: Optional tensor_pb2 DataType value. 400 shape: List of integers representing the dimensions of tensor. 401 verify_shape: Boolean that enables verification of a shape of values. 402 allow_broadcast: Boolean that enables allowing scalars and 1 length vector 403 broadcasting. Cannot be true when verify_shape is true. 404 405 Returns: 406 A `TensorProto`. Depending on the type, it may contain data in the 407 "tensor_content" attribute, which is not directly useful to Python programs. 408 To access the values you should convert the proto back to a numpy ndarray 409 with `tf.make_ndarray(proto)`. 410 411 If `values` is a `TensorProto`, it is immediately returned; `dtype` and 412 `shape` are ignored. 413 414 Raises: 415 TypeError: if unsupported types are provided. 416 ValueError: if arguments have inappropriate values or if verify_shape is 417 True and shape of values is not equals to a shape from the argument. 418 419 """ 420 if allow_broadcast and verify_shape: 421 raise ValueError("allow_broadcast and verify_shape are not both allowed.") 422 if isinstance(values, tensor_pb2.TensorProto): 423 return values 424 425 if dtype: 426 dtype = dtypes.as_dtype(dtype) 427 428 is_quantized = ( 429 dtype in [ 430 dtypes.qint8, dtypes.quint8, dtypes.qint16, dtypes.quint16, 431 dtypes.qint32 432 ]) 433 434 if _is_array_like(values): 435 values = np.asarray(values) 436 437 # We first convert value to a numpy array or scalar. 438 if isinstance(values, (np.ndarray, np.generic)): 439 if dtype and dtype.is_numpy_compatible: 440 nparray = values.astype(dtype.as_numpy_dtype) 441 else: 442 nparray = values 443 else: 444 if values is None: 445 raise ValueError("None values not supported.") 446 # if dtype is provided, forces numpy array to be the type 447 # provided if possible. 448 if dtype and dtype.is_numpy_compatible: 449 np_dt = dtype.as_numpy_dtype 450 else: 451 np_dt = None 452 # If shape is None, numpy.prod returns None when dtype is not set, but 453 # raises exception when dtype is set to np.int64 454 if shape is not None and np.prod(shape, dtype=np.int64) == 0: 455 nparray = np.empty(shape, dtype=np_dt) 456 else: 457 _AssertCompatible(values, dtype) 458 nparray = np.array(values, dtype=np_dt) 459 # check to them. 460 # We need to pass in quantized values as tuples, so don't apply the shape 461 if (list(nparray.shape) != _GetDenseDimensions(values) and 462 not is_quantized): 463 raise ValueError("""Argument must be a dense tensor: %s""" 464 """ - got shape %s, but wanted %s.""" % 465 (values, list(nparray.shape), 466 _GetDenseDimensions(values))) 467 468 # python/numpy default float type is float64. We prefer float32 instead. 469 if (nparray.dtype == np.float64) and dtype is None: 470 nparray = nparray.astype(np.float32) 471 # python/numpy default int type is int64. We prefer int32 instead. 472 elif (nparray.dtype == np.int64) and dtype is None: 473 downcasted_array = nparray.astype(np.int32) 474 # Do not down cast if it leads to precision loss. 475 if np.array_equal(downcasted_array, nparray): 476 nparray = downcasted_array 477 478 # if dtype is provided, it must be compatible with what numpy 479 # conversion says. 480 numpy_dtype = dtypes.as_dtype(nparray.dtype) 481 if numpy_dtype is None: 482 raise TypeError("Unrecognized data type: %s" % nparray.dtype) 483 484 # If dtype was specified and is a quantized type, we convert 485 # numpy_dtype back into the quantized version. 486 if is_quantized: 487 numpy_dtype = dtype 488 489 if dtype is not None and (not hasattr(dtype, "base_dtype") or 490 dtype.base_dtype != numpy_dtype.base_dtype): 491 raise TypeError("Incompatible types: %s vs. %s. Value is %s" % 492 (dtype, nparray.dtype, values)) 493 494 # If shape is not given, get the shape from the numpy array. 495 if shape is None: 496 shape = nparray.shape 497 is_same_size = True 498 shape_size = nparray.size 499 else: 500 shape = [int(dim) for dim in shape] 501 shape_size = np.prod(shape, dtype=np.int64) 502 is_same_size = shape_size == nparray.size 503 504 if allow_broadcast: 505 if nparray.shape == (1,) or nparray.shape == tuple(): 506 pass 507 elif nparray.size != shape_size: 508 raise TypeError("Expected Tensor's shape: %s, got %s." % 509 (tuple(shape), nparray.shape)) 510 511 else: 512 if verify_shape and nparray.shape != tuple(shape): 513 raise TypeError("Expected Tensor's shape: %s, got %s." % 514 (tuple(shape), nparray.shape)) 515 516 if nparray.size > shape_size: 517 raise ValueError( 518 "Too many elements provided. Needed at most %d, but received %d" % 519 (shape_size, nparray.size)) 520 521 tensor_proto = tensor_pb2.TensorProto( 522 dtype=numpy_dtype.as_datatype_enum, 523 tensor_shape=tensor_shape.as_shape(shape).as_proto()) 524 525 if is_same_size and numpy_dtype in _TENSOR_CONTENT_TYPES and shape_size > 1: 526 if nparray.size * nparray.itemsize >= (1 << 31): 527 raise ValueError( 528 "Cannot create a tensor proto whose content is larger than 2GB.") 529 tensor_proto.tensor_content = nparray.tobytes() 530 return tensor_proto 531 532 # If we were not given values as a numpy array, compute the proto_values 533 # from the given values directly, to avoid numpy trimming nulls from the 534 # strings. Since values could be a list of strings, or a multi-dimensional 535 # list of lists that might or might not correspond to the given shape, 536 # we flatten it conservatively. 537 if numpy_dtype == dtypes.string and not isinstance(values, np.ndarray): 538 proto_values = _FlattenToStrings(values) 539 540 # At this point, values may be a list of objects that we could not 541 # identify a common type for (hence it was inferred as 542 # np.object/dtypes.string). If we are unable to convert it to a 543 # string, we raise a more helpful error message. 544 # 545 # Ideally, we'd be able to convert the elements of the list to a 546 # common type, but this type inference requires some thinking and 547 # so we defer it for now. 548 try: 549 str_values = [compat.as_bytes(x) for x in proto_values] 550 except TypeError: 551 raise TypeError("Failed to convert object of type %s to Tensor. " 552 "Contents: %s. Consider casting elements to a " 553 "supported type." % (type(values), values)) 554 tensor_proto.string_val.extend(str_values) 555 return tensor_proto 556 557 # TensorFlow expects C order (a.k.a., eigen row major). 558 proto_values = nparray.ravel() 559 560 append_fn = GetNumpyAppendFn(proto_values.dtype) 561 if append_fn is None: 562 raise TypeError( 563 "Element type not supported in TensorProto: %s" % numpy_dtype.name) 564 append_fn(tensor_proto, proto_values) 565 566 return tensor_proto 567# pylint: enable=invalid-name 568 569 570@tf_export("make_ndarray") 571def MakeNdarray(tensor): 572 """Create a numpy ndarray from a tensor. 573 574 Create a numpy ndarray with the same shape and data as the tensor. 575 576 For example: 577 578 ```python 579 # Tensor a has shape (2,3) 580 a = tf.constant([[1,2,3],[4,5,6]]) 581 proto_tensor = tf.make_tensor_proto(a) # convert `tensor a` to a proto tensor 582 tf.make_ndarray(proto_tensor) # output: array([[1, 2, 3], 583 # [4, 5, 6]], dtype=int32) 584 # output has shape (2,3) 585 ``` 586 587 Args: 588 tensor: A TensorProto. 589 590 Returns: 591 A numpy array with the tensor contents. 592 593 Raises: 594 TypeError: if tensor has unsupported type. 595 596 """ 597 shape = [d.size for d in tensor.tensor_shape.dim] 598 num_elements = np.prod(shape, dtype=np.int64) 599 tensor_dtype = dtypes.as_dtype(tensor.dtype) 600 dtype = tensor_dtype.as_numpy_dtype 601 602 if tensor.tensor_content: 603 return (np.frombuffer(tensor.tensor_content, 604 dtype=dtype).copy().reshape(shape)) 605 606 if tensor_dtype == dtypes.string: 607 # np.pad throws on these arrays of type np.object. 608 values = list(tensor.string_val) 609 padding = num_elements - len(values) 610 if padding > 0: 611 last = values[-1] if values else "" 612 values.extend([last] * padding) 613 return np.array(values, dtype=dtype).reshape(shape) 614 615 if tensor_dtype == dtypes.float16 or tensor_dtype == dtypes.bfloat16: 616 # the half_val field of the TensorProto stores the binary representation 617 # of the fp16: we need to reinterpret this as a proper float16 618 values = np.fromiter(tensor.half_val, dtype=np.uint16) 619 values.dtype = tensor_dtype.as_numpy_dtype 620 elif tensor_dtype == dtypes.float32: 621 values = np.fromiter(tensor.float_val, dtype=dtype) 622 elif tensor_dtype == dtypes.float64: 623 values = np.fromiter(tensor.double_val, dtype=dtype) 624 elif tensor_dtype in [ 625 dtypes.int32, dtypes.uint8, dtypes.uint16, dtypes.int16, dtypes.int8, 626 dtypes.qint32, dtypes.quint8, dtypes.qint8, dtypes.qint16, dtypes.quint16 627 ]: 628 values = np.fromiter(tensor.int_val, dtype=dtype) 629 elif tensor_dtype == dtypes.int64: 630 values = np.fromiter(tensor.int64_val, dtype=dtype) 631 elif tensor_dtype == dtypes.uint32: 632 values = np.fromiter(tensor.uint32_val, dtype=dtype) 633 elif tensor_dtype == dtypes.uint64: 634 values = np.fromiter(tensor.uint64_val, dtype=dtype) 635 elif tensor_dtype == dtypes.complex64: 636 it = iter(tensor.scomplex_val) 637 values = np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype) 638 elif tensor_dtype == dtypes.complex128: 639 it = iter(tensor.dcomplex_val) 640 values = np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype) 641 elif tensor_dtype == dtypes.bool: 642 values = np.fromiter(tensor.bool_val, dtype=dtype) 643 else: 644 raise TypeError("Unsupported tensor type: %s" % tensor.dtype) 645 646 if values.size == 0: 647 return np.zeros(shape, dtype) 648 649 if values.size != num_elements: 650 values = np.pad(values, (0, num_elements - values.size), "edge") 651 652 return values.reshape(shape) 653 654 655def ShapeEquals(tensor_proto, shape): 656 """Returns True if "tensor_proto" has the given "shape". 657 658 Args: 659 tensor_proto: A TensorProto. 660 shape: A tensor shape, expressed as a TensorShape, list, or tuple. 661 662 Returns: 663 True if "tensor_proto" has the given "shape", otherwise False. 664 665 Raises: 666 TypeError: If "tensor_proto" is not a TensorProto, or shape is not a 667 TensorShape, list, or tuple. 668 """ 669 if not isinstance(tensor_proto, tensor_pb2.TensorProto): 670 raise TypeError("tensor_proto is not a tensor_pb2.TensorProto object") 671 if isinstance(shape, tensor_shape_pb2.TensorShapeProto): 672 shape = [d.size for d in shape.dim] 673 elif not isinstance(shape, (list, tuple)): 674 raise TypeError("shape is not a list or tuple") 675 tensor_shape_list = [d.size for d in tensor_proto.tensor_shape.dim] 676 return all(x == y for x, y in zip(tensor_shape_list, shape)) 677 678 679def _ConstantValue(tensor, partial): 680 # TODO(touts): Support Variables? 681 if not isinstance(tensor, ops.Tensor): 682 raise TypeError("%r is not a Tensor, has type %s" % (tensor, type(tensor))) 683 if tensor.op.type == "Const": 684 return MakeNdarray(tensor.op.get_attr("value")) 685 elif tensor.op.type == "Shape": 686 input_shape = tensor.op.inputs[0].get_shape() 687 if input_shape.is_fully_defined(): 688 return np.array( 689 [dim.value for dim in input_shape.dims], 690 dtype=tensor.dtype.as_numpy_dtype) 691 else: 692 return None 693 elif tensor.op.type == "Size": 694 input_shape = tensor.op.inputs[0].get_shape() 695 if input_shape.is_fully_defined(): 696 return np.prod([dim.value for dim in input_shape.dims], dtype=np.int32) 697 else: 698 return None 699 elif tensor.op.type == "Rank": 700 input_shape = tensor.op.inputs[0].get_shape() 701 if input_shape.ndims is not None: 702 return np.ndarray( 703 shape=(), 704 buffer=np.array([input_shape.ndims], dtype=np.int32), 705 dtype=np.int32) 706 else: 707 return None 708 elif tensor.op.type == "Range": 709 start = constant_value(tensor.op.inputs[0]) 710 if start is None: 711 return None 712 limit = constant_value(tensor.op.inputs[1]) 713 if limit is None: 714 return None 715 delta = constant_value(tensor.op.inputs[2]) 716 if delta is None: 717 return None 718 return np.arange(start, limit, delta, dtype=tensor.dtype.as_numpy_dtype) 719 elif tensor.op.type == "Cast": 720 pre_cast = constant_value(tensor.op.inputs[0]) 721 if pre_cast is None: 722 return None 723 cast_dtype = dtypes.as_dtype(tensor.op.get_attr("DstT")) 724 return pre_cast.astype(cast_dtype.as_numpy_dtype) 725 elif tensor.op.type == "Concat": 726 dim = constant_value(tensor.op.inputs[0]) 727 if dim is None: 728 return None 729 values = [] 730 for x in tensor.op.inputs[1:]: 731 value = constant_value(x) 732 if value is None: 733 return None 734 values.append(value) 735 return np.concatenate(values, axis=dim) 736 elif tensor.op.type == "ConcatV2": 737 dim = constant_value(tensor.op.inputs[-1]) 738 if dim is None: 739 return None 740 values = [] 741 for x in tensor.op.inputs[:-1]: 742 value = constant_value(x) 743 if value is None: 744 return None 745 values.append(value) 746 return np.concatenate(values, axis=dim) 747 elif tensor.op.type == "Pack": 748 values = [] 749 # Some imported GraphDefs have Pack ops with zero inputs. Those are invalid 750 # and shouldn't be produced, but to deal sensibly with them here we check 751 # and return None. 752 if not tensor.op.inputs: 753 return None 754 # We can't handle axis != 0 Packs at the moment. 755 if tensor.op.get_attr("axis") != 0: 756 return None 757 for x in tensor.op.inputs: 758 value = constant_value(x, partial) 759 if value is None and not partial: 760 return None 761 values.append(value) 762 return np.array(values) 763 elif tensor.op.type == "Unpack": 764 # We can't handle axis != 0 Unpacks at the moment. 765 if tensor.op.get_attr("axis") != 0: 766 return None 767 value = constant_value(tensor.op.inputs[0], partial) 768 if value is None: 769 return None 770 return value[tensor.value_index] 771 elif tensor.op.type == "Split": 772 dim = constant_value(tensor.op.inputs[0]) 773 value = constant_value(tensor.op.inputs[1], partial) 774 if value is None or dim is None: 775 return None 776 split = np.split(value, tensor.op.get_attr("num_split"), dim) 777 return split[tensor.value_index] 778 elif tensor.op.type == "Fill": 779 fill_shape = tensor.shape 780 fill_value = constant_value(tensor.op.inputs[1]) 781 if fill_shape.is_fully_defined() and fill_value is not None: 782 return np.full(fill_shape.as_list(), fill_value, dtype=fill_value.dtype) 783 else: 784 return None 785 elif tensor.op.type == "Equal": 786 value1 = constant_value(tensor.op.inputs[0]) 787 if value1 is None: 788 return None 789 value2 = constant_value(tensor.op.inputs[1]) 790 if value2 is None: 791 return None 792 return np.equal(value1, value2) 793 elif tensor.op.type == "NotEqual": 794 value1 = constant_value(tensor.op.inputs[0]) 795 if value1 is None: 796 return None 797 value2 = constant_value(tensor.op.inputs[1]) 798 if value2 is None: 799 return None 800 return np.not_equal(value1, value2) 801 elif tensor.op.type == "StopGradient": 802 return constant_value(tensor.op.inputs[0], partial) 803 elif tensor.op.type in ("CheckNumericsV2", "DebugIdentityV2", "Identity"): 804 return constant_value(tensor.op.inputs[0], partial) 805 else: 806 return None 807 808 809@tf_export("get_static_value") 810def constant_value(tensor, partial=False): # pylint: disable=invalid-name 811 """Returns the constant value of the given tensor, if efficiently calculable. 812 813 This function attempts to partially evaluate the given tensor, and 814 returns its value as a numpy ndarray if this succeeds. 815 816 Example usage: 817 818 >>> a = tf.constant(10) 819 >>> tf.get_static_value(a) 820 10 821 >>> b = tf.constant(20) 822 >>> tf.get_static_value(tf.add(a, b)) 823 30 824 825 >>> # `tf.Variable` is not supported. 826 >>> c = tf.Variable(30) 827 >>> print(tf.get_static_value(c)) 828 None 829 830 Using `partial` option is most relevant when calling `get_static_value` inside 831 a `tf.function`. Setting it to `True` will return the results but for the 832 values that cannot be evaluated will be `None`. For example: 833 834 ```python 835 class Foo(object): 836 def __init__(self): 837 self.a = tf.Variable(1) 838 self.b = tf.constant(2) 839 840 @tf.function 841 def bar(self, partial): 842 packed = tf.raw_ops.Pack(values=[self.a, self.b]) 843 static_val = tf.get_static_value(packed, partial=partial) 844 tf.print(static_val) 845 846 f = Foo() 847 f.bar(partial=True) # `array([None, array(2, dtype=int32)], dtype=object)` 848 f.bar(partial=False) # `None` 849 ``` 850 851 Compatibility(V1): If `constant_value(tensor)` returns a non-`None` result, it 852 will no longer be possible to feed a different value for `tensor`. This allows 853 the result of this function to influence the graph that is constructed, and 854 permits static shape optimizations. 855 856 Args: 857 tensor: The Tensor to be evaluated. 858 partial: If True, the returned numpy array is allowed to have partially 859 evaluated values. Values that can't be evaluated will be None. 860 861 Returns: 862 A numpy ndarray containing the constant value of the given `tensor`, 863 or None if it cannot be calculated. 864 865 Raises: 866 TypeError: if tensor is not an ops.Tensor. 867 """ 868 if isinstance(tensor, ops.EagerTensor): 869 try: 870 return tensor.numpy() 871 except errors_impl.UnimplementedError: 872 # Some EagerTensors may not implement .numpy/resolve, e.g. parallel 873 # tensors with multiple components on different devices. 874 return None 875 if not is_tensor(tensor): 876 return tensor 877 if not isinstance(tensor, ops.Tensor): 878 return None 879 ret = _ConstantValue(tensor, partial) 880 if ret is not None: 881 # The caller may now depend on the constant value of `tensor`, so we 882 # conservatively prevent it from being fed. 883 tensor.graph.prevent_feeding(tensor) 884 return ret 885 886 887def constant_value_as_shape(tensor): # pylint: disable=invalid-name 888 """A version of `constant_value()` that returns a `TensorShape`. 889 890 This version should be used when a constant tensor value is 891 interpreted as a (possibly partial) shape, e.g. in the shape 892 function for `tf.reshape()`. By explicitly requesting a 893 `TensorShape` as the return value, it is possible to represent 894 unknown dimensions; by contrast, `constant_value()` is 895 all-or-nothing. 896 897 Args: 898 tensor: The rank-0 or rank-1 Tensor to be evaluated. 899 900 Returns: 901 A `TensorShape` based on the constant value of the given `tensor`. 902 903 Raises: 904 ValueError: If the shape is rank-0 and is not statically known to be -1. 905 """ 906 if isinstance(tensor, ops.EagerTensor): 907 return tensor_shape.TensorShape( 908 [dim if dim != -1 else None for dim in tensor.numpy()]) 909 910 if tensor.get_shape().ndims == 0: 911 value = constant_value(tensor) 912 if value is None: 913 raise ValueError( 914 "Received a scalar with unknown value as shape; require a statically " 915 "known scalar with value '-1' to describe an unknown shape.") 916 if value != -1: 917 raise ValueError( 918 "Received a scalar value '%s' as shape; require a statically known " 919 "scalar with value '-1' to describe an unknown shape." % value) 920 return tensor_shape.unknown_shape() 921 922 shape = tensor.get_shape().with_rank(1) 923 if shape == [0]: 924 return tensor_shape.TensorShape([]) 925 elif tensor.op.type == "Cast": 926 pre_cast = constant_value_as_shape(tensor.op.inputs[0]) 927 if pre_cast.dims is None: 928 # the input to cast has a totally undefined shape; just return that. 929 return pre_cast 930 cast_dtype = dtypes.as_dtype(tensor.op.get_attr("DstT")) 931 if cast_dtype not in (dtypes.int32, dtypes.int64): 932 return tensor_shape.unknown_shape(shape.dims[0].value) 933 dest_dtype_shape_array = np.array( 934 [x if x is not None else -1 for x in pre_cast.as_list()]).astype( 935 cast_dtype.as_numpy_dtype) 936 return tensor_shape.TensorShape([ 937 x if x >= 0 else None 938 for x in dest_dtype_shape_array]) 939 elif tensor.op.type == "Shape": 940 return tensor.op.inputs[0].get_shape() 941 elif tensor.op.type == "Pack": 942 ret = tensor_shape.TensorShape([]) # Empty list. 943 # Since we expect rank 1 inputs, Pack's axis must be zero, otherwise it 944 # would not be rank 1. 945 assert tensor.op.get_attr("axis") == 0 946 for pack_input in tensor.op.inputs: 947 # `pack_input` must be a scalar. Attempt to evaluate it, and append it 948 # to `ret`. 949 pack_input_val = constant_value(pack_input) 950 if pack_input_val is None or pack_input_val < 0: 951 new_dim = tensor_shape.Dimension(None) 952 else: 953 new_dim = tensor_shape.Dimension(pack_input_val) 954 ret = ret.concatenate([new_dim]) 955 return ret 956 elif tensor.op.type == "Concat": 957 # We assume that `tensor.op.inputs[0]` evaluates to 0, as this is 958 # the only legal value when concatenating vectors, and it will 959 # have been checked by a previous shape function. 960 ret = tensor_shape.TensorShape([]) # Empty list. 961 for concat_input in tensor.op.inputs[1:]: 962 # `concat_input` must be a vector. Attempt to evaluate it as a shape, 963 # and concatenate it with `ret`. 964 ret = ret.concatenate(constant_value_as_shape(concat_input)) 965 return ret 966 elif tensor.op.type == "ConcatV2": 967 # We assume that `tensor.op.inputs[-1]` evaluates to 0, as this is 968 # the only legal value when concatenating vectors, and it will 969 # have been checked by a previous shape function. 970 ret = tensor_shape.TensorShape([]) # Empty list. 971 for concat_input in tensor.op.inputs[:-1]: 972 # `concat_input` must be a vector. Attempt to evaluate it as a shape, 973 # and concatenate it with `ret`. 974 ret = ret.concatenate(constant_value_as_shape(concat_input)) 975 return ret 976 elif tensor.op.type == "StridedSlice": 977 try: 978 begin = constant_value(tensor.op.inputs[1]) 979 end = constant_value(tensor.op.inputs[2]) 980 strides = constant_value(tensor.op.inputs[3]) 981 if begin is not None and end is not None and strides is not None: 982 begin = begin[0] 983 end = end[0] 984 strides = strides[0] 985 begin_mask = tensor.op.get_attr("begin_mask") 986 if begin_mask == 1: 987 begin = None 988 end_mask = tensor.op.get_attr("end_mask") 989 if end_mask == 1: 990 end = None 991 992 ellipsis_mask = tensor.op.get_attr("ellipsis_mask") 993 new_axis_mask = tensor.op.get_attr("new_axis_mask") 994 shrink_axis_mask = tensor.op.get_attr("shrink_axis_mask") 995 valid_attributes = (not ellipsis_mask and not new_axis_mask and 996 not shrink_axis_mask and (not begin_mask or 997 (begin_mask == 1)) and 998 (not end_mask or (end_mask == 1))) 999 if valid_attributes: # additional inputs not supported 1000 prev = constant_value_as_shape(tensor.op.inputs[0]) 1001 prev = prev[begin:end:strides] 1002 ret = tensor_shape.TensorShape(prev) 1003 return ret 1004 1005 except ValueError: # Could come from get_attr or slicing prev. 1006 pass 1007 except TypeError: # Could come from slicing prev. 1008 pass 1009 elif (tensor.op.type == "Placeholder" and 1010 tensor.op.graph.building_function and 1011 hasattr(tensor.op.graph, "internal_captures")): 1012 # If we are inside a FuncGraph try to lookup the constant value of the 1013 # corresponding external capture. Note that we only look at captures and 1014 # not the fed inputs because those can be fed different values in different 1015 # instantiations of the function call or different iterations of a 1016 # tf.while_loop. 1017 for i, capture in enumerate(tensor.op.graph.internal_captures): 1018 if capture is tensor: 1019 external_capture = tensor.op.graph.external_captures[i] 1020 return constant_value_as_shape(external_capture) 1021 1022 ret = tensor_shape.unknown_shape(shape.dims[0].value) 1023 value = constant_value(tensor) 1024 if value is not None: 1025 ret = ret.merge_with( 1026 tensor_shape.TensorShape([d if d >= 0 else None for d in value])) 1027 return ret 1028 1029 1030# TODO(mdan): Deprecate in favor of more static-friendly types. 1031@tf_export("is_tensor") 1032def is_tf_type(x): # pylint: disable=invalid-name 1033 """Checks whether `x` is a TF-native type that can be passed to many TF ops. 1034 1035 Use `is_tensor` to differentiate types that can ingested by TensorFlow ops 1036 without any conversion (e.g., `tf.Tensor`, `tf.SparseTensor`, and 1037 `tf.RaggedTensor`) from types that need to be converted into tensors before 1038 they are ingested (e.g., numpy `ndarray` and Python scalars). 1039 1040 For example, in the following code block: 1041 1042 ```python 1043 if not tf.is_tensor(t): 1044 t = tf.convert_to_tensor(t) 1045 return t.shape, t.dtype 1046 ``` 1047 1048 we check to make sure that `t` is a tensor (and convert it if not) before 1049 accessing its `shape` and `dtype`. (But note that not all TensorFlow native 1050 types have shapes or dtypes; `tf.data.Dataset` is an example of a TensorFlow 1051 native type that has neither shape nor dtype.) 1052 1053 Args: 1054 x: A python object to check. 1055 1056 Returns: 1057 `True` if `x` is a TensorFlow-native type. 1058 """ 1059 return (isinstance(x, internal.NativeObject) or 1060 isinstance(x, core.Tensor) or 1061 getattr(x, "is_tensor_like", False)) 1062 1063 1064# Deprecated alias for tensor_util.is_tf_type. 1065is_tensor = is_tf_type 1066 1067 1068def shape_tensor(shape): # pylint: disable=invalid-name 1069 """Convert to an int32 or int64 tensor, defaulting to int32 if empty.""" 1070 dtype = None 1071 if isinstance(shape, (tuple, list)): 1072 if not shape: 1073 dtype = dtypes.int32 1074 else: 1075 # If there are Dimension objects in the shape, unwrap them. This can be a 1076 # problem if v1 and v2 TensorShape objects get mixed up in partial 1077 # conversions, leading to shapes such as (1, 2, Dimension(5)), which are 1078 # not convertible to Tensors because of mixed content. 1079 shape = tuple(map(tensor_shape.dimension_value, shape)) 1080 return ops.convert_to_tensor(shape, dtype=dtype, name="shape") 1081 1082 1083# DO NOT USE: For testing only. 1084_ENABLE_MAYBE_SET_STATIC_SHAPE = True 1085 1086 1087def maybe_set_static_shape(tensor, shape): # pylint: disable=invalid-name 1088 """Sets the shape of `tensor` to the `shape`'s constant value, if inferrable. 1089 1090 This is a temporary workaround to fix shape inference across functional op 1091 boundaries. E.g. 1092 1093 ```python 1094 shape = tf.constant([3]) 1095 @tf.function 1096 def f(): 1097 u = tf.random_uniform(shape) 1098 return u 1099 ``` 1100 1101 If we were to rely solely on C++ shape inference, the shape of `u` inside 1102 `f` would be unknown because C++ shape inference is not aware of the outer 1103 graph and all it sees is a Placeholder node when backtracing the captured 1104 tensor for `shape`. `maybe_set_static_shape` computes the static shape value 1105 of `shape` by traversing the `FuncGraph` boundaries and sets the correct 1106 shape. 1107 1108 A longer term solution would be to fix C++ shape inference. 1109 1110 Args: 1111 tensor: A tensor. 1112 shape: A shape tensor. 1113 """ 1114 if (_ENABLE_MAYBE_SET_STATIC_SHAPE and not context.executing_eagerly() and 1115 ops.get_default_graph().building_function and 1116 not tensor.shape.is_fully_defined() and is_tensor(shape)): 1117 shape = shape_tensor(shape) 1118 const_shape = constant_value_as_shape(shape) 1119 tensor.set_shape(const_shape) 1120