1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities to create TensorProtos.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21import six 22 23from tensorflow.core.framework import tensor_pb2 24from tensorflow.core.framework import tensor_shape_pb2 25from tensorflow.python.framework import composite_tensor 26from tensorflow.python.framework import ops 27from tensorflow.python.framework import tensor_shape 28from tensorflow.python.util import compat 29 30# Fallback in case fast_tensor_util is not properly compiled. 
# pylint: disable=g-import-not-at-top
try:
  from tensorflow.python.framework import fast_tensor_util
  _FAST_TENSOR_UTIL_AVAILABLE = True
except ImportError:
  _FAST_TENSOR_UTIL_AVAILABLE = False

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.util.tf_export import tf_export

# pylint: enable=g-import-not-at-top


def ExtractBitsFromFloat16(x):
  """Returns the 16-bit pattern of `x` (converted to float16) as a Python int."""
  return np.asarray(x, dtype=np.float16).view(np.uint16).item()


def SlowAppendFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Appends float16 values to `tensor_proto.half_val` one element at a time.

  Each value is stored as its raw 16-bit pattern (see ExtractBitsFromFloat16).
  """
  tensor_proto.half_val.extend(
      [ExtractBitsFromFloat16(x) for x in proto_values])


def _MediumAppendFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Appends float16 values via fast_tensor_util, passing their uint16 bits."""
  # TODO: Remove the conversion if cython supports np.float16_t
  fast_tensor_util.AppendFloat16ArrayToTensorProto(
      tensor_proto,
      np.asarray(proto_values, dtype=np.float16).view(np.uint16))


def ExtractBitsFromBFloat16(x):
  """Returns the 16-bit pattern of `x` (converted to bfloat16) as a Python int."""
  return np.asarray(
      x, dtype=dtypes.bfloat16.as_numpy_dtype).view(np.uint16).item()


def SlowAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Appends bfloat16 values to `tensor_proto.half_val` as raw 16-bit patterns."""
  tensor_proto.half_val.extend(
      [ExtractBitsFromBFloat16(x) for x in proto_values])


def FastAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Appends bfloat16 values via fast_tensor_util, passing their uint16 bits."""
  fast_tensor_util.AppendBFloat16ArrayToTensorProto(
      tensor_proto, np.asarray(
          proto_values, dtype=dtypes.bfloat16.as_numpy_dtype).view(np.uint16))


# _NP_TO_APPEND_FN maps a numpy scalar type to a function
# `fn(tensor_proto, proto_values)` that appends a flat array of that type to
# the proto. Look entries up with GetFromNumpyDTypeDict, not dict.get().
#
# The builtins `object` and `bool` are used as keys instead of the former
# `np.object` / `np.bool` aliases: those aliases were these very builtins, but
# NumPy 1.24 removed them, which made this module fail at import time.
if _FAST_TENSOR_UTIL_AVAILABLE:
  _NP_TO_APPEND_FN = {
      dtypes.bfloat16.as_numpy_dtype:
          FastAppendBFloat16ArrayToTensorProto,
      np.float16:
          _MediumAppendFloat16ArrayToTensorProto,
      np.float32:
          fast_tensor_util.AppendFloat32ArrayToTensorProto,
      np.float64:
          fast_tensor_util.AppendFloat64ArrayToTensorProto,
      np.int32:
          fast_tensor_util.AppendInt32ArrayToTensorProto,
      np.int64:
          fast_tensor_util.AppendInt64ArrayToTensorProto,
      np.uint8:
          fast_tensor_util.AppendUInt8ArrayToTensorProto,
      np.uint16:
          fast_tensor_util.AppendUInt16ArrayToTensorProto,
      np.uint32:
          fast_tensor_util.AppendUInt32ArrayToTensorProto,
      np.uint64:
          fast_tensor_util.AppendUInt64ArrayToTensorProto,
      np.int8:
          fast_tensor_util.AppendInt8ArrayToTensorProto,
      np.int16:
          fast_tensor_util.AppendInt16ArrayToTensorProto,
      np.complex64:
          fast_tensor_util.AppendComplex64ArrayToTensorProto,
      np.complex128:
          fast_tensor_util.AppendComplex128ArrayToTensorProto,
      object:
          fast_tensor_util.AppendObjectArrayToTensorProto,
      bool:
          fast_tensor_util.AppendBoolArrayToTensorProto,
      dtypes.qint8.as_numpy_dtype:
          fast_tensor_util.AppendInt8ArrayToTensorProto,
      dtypes.quint8.as_numpy_dtype:
          fast_tensor_util.AppendUInt8ArrayToTensorProto,
      # 16-bit quantized types must use the 16-bit appenders, mirroring the
      # np.int16 / np.uint16 entries above; the 8-bit appenders previously
      # used here do not match the element width.
      dtypes.qint16.as_numpy_dtype:
          fast_tensor_util.AppendInt16ArrayToTensorProto,
      dtypes.quint16.as_numpy_dtype:
          fast_tensor_util.AppendUInt16ArrayToTensorProto,
      dtypes.qint32.as_numpy_dtype:
          fast_tensor_util.AppendInt32ArrayToTensorProto,
  }
else:

  def SlowAppendFloat32ArrayToTensorProto(tensor_proto, proto_values):
    """Appends float32 values to `tensor_proto.float_val`."""
    tensor_proto.float_val.extend([x.item() for x in proto_values])

  def SlowAppendFloat64ArrayToTensorProto(tensor_proto, proto_values):
    """Appends float64 values to `tensor_proto.double_val`."""
    tensor_proto.double_val.extend([x.item() for x in proto_values])

  def SlowAppendIntArrayToTensorProto(tensor_proto, proto_values):
    """Appends integer values to `tensor_proto.int_val`."""
    tensor_proto.int_val.extend([x.item() for x in proto_values])

  def SlowAppendInt64ArrayToTensorProto(tensor_proto, proto_values):
    """Appends int64 values to `tensor_proto.int64_val`."""
    tensor_proto.int64_val.extend([x.item() for x in proto_values])

  def SlowAppendQIntArrayToTensorProto(tensor_proto, proto_values):
    """Appends quantized values to `int_val` (field [0] of each `.item()`)."""
    tensor_proto.int_val.extend([x.item()[0] for x in proto_values])

  def SlowAppendUInt32ArrayToTensorProto(tensor_proto, proto_values):
    """Appends uint32 values to `tensor_proto.uint32_val`."""
    tensor_proto.uint32_val.extend([x.item() for x in proto_values])

  def SlowAppendUInt64ArrayToTensorProto(tensor_proto, proto_values):
    """Appends uint64 values to `tensor_proto.uint64_val`."""
    tensor_proto.uint64_val.extend([x.item() for x in proto_values])

  def SlowAppendComplex64ArrayToTensorProto(tensor_proto, proto_values):
    """Appends complex64 values to `scomplex_val` as (real, imag) pairs."""
    tensor_proto.scomplex_val.extend(
        [v.item() for x in proto_values for v in [x.real, x.imag]])

  def SlowAppendComplex128ArrayToTensorProto(tensor_proto, proto_values):
    """Appends complex128 values to `dcomplex_val` as (real, imag) pairs."""
    tensor_proto.dcomplex_val.extend(
        [v.item() for x in proto_values for v in [x.real, x.imag]])

  def SlowAppendObjectArrayToTensorProto(tensor_proto, proto_values):
    """Appends values to `string_val`, converting each with compat.as_bytes."""
    tensor_proto.string_val.extend([compat.as_bytes(x) for x in proto_values])

  def SlowAppendBoolArrayToTensorProto(tensor_proto, proto_values):
    """Appends boolean values to `tensor_proto.bool_val`."""
    tensor_proto.bool_val.extend([x.item() for x in proto_values])

  _NP_TO_APPEND_FN = {
      dtypes.bfloat16.as_numpy_dtype: SlowAppendBFloat16ArrayToTensorProto,
      np.float16: SlowAppendFloat16ArrayToTensorProto,
      np.float32: SlowAppendFloat32ArrayToTensorProto,
      np.float64: SlowAppendFloat64ArrayToTensorProto,
      np.int32: SlowAppendIntArrayToTensorProto,
      np.int64: SlowAppendInt64ArrayToTensorProto,
      np.uint8: SlowAppendIntArrayToTensorProto,
      np.uint16: SlowAppendIntArrayToTensorProto,
      np.uint32: SlowAppendUInt32ArrayToTensorProto,
      np.uint64: SlowAppendUInt64ArrayToTensorProto,
      np.int8: SlowAppendIntArrayToTensorProto,
      np.int16: SlowAppendIntArrayToTensorProto,
      np.complex64: SlowAppendComplex64ArrayToTensorProto,
      np.complex128: SlowAppendComplex128ArrayToTensorProto,
      object: SlowAppendObjectArrayToTensorProto,
      bool: SlowAppendBoolArrayToTensorProto,
      dtypes.qint8.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
      dtypes.quint8.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
      dtypes.qint16.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
      dtypes.quint16.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
      dtypes.qint32.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
  }


def GetFromNumpyDTypeDict(dtype_dict, dtype):
  """Returns the value in `dtype_dict` whose key compares equal to `dtype`.

  Args:
    dtype_dict: A dict keyed by numpy scalar types (e.g. _NP_TO_APPEND_FN).
    dtype: The numpy dtype to look up.

  Returns:
    The value of the first key that satisfies `key == dtype`, or None.
  """
  # NOTE: dtype_dict.get(dtype) always returns None: the keys are scalar
  # types, which hash differently from the np.dtype being looked up even
  # though they compare equal, so a linear scan with `==` is required.
  for key, val in six.iteritems(dtype_dict):
    if key == dtype:
      return val
  return None


def GetNumpyAppendFn(dtype):
  """Returns the function that appends a flat array of `dtype` to a proto.

  Args:
    dtype: A numpy dtype.

  Returns:
    A callable `fn(tensor_proto, proto_values)`, or None if `dtype` has no
    registered append function.
  """
  # numpy dtype for strings are variable length. We can not compare
  # dtype with a single constant to decide whether dtype is a "string" type.
  # Kind "S" (bytes) and "U" (unicode) are exactly the dtypes whose `.type`
  # is np.bytes_ / np.str_ (formerly np.string_ / np.unicode_, both removed
  # in NumPy 2.0); checking `kind` works on every NumPy and Python version.
  if dtype.kind in ("S", "U"):
    if _FAST_TENSOR_UTIL_AVAILABLE:
      return fast_tensor_util.AppendObjectArrayToTensorProto
    else:
      return SlowAppendObjectArrayToTensorProto
  return GetFromNumpyDTypeDict(_NP_TO_APPEND_FN, dtype)


def TensorShapeProtoToList(shape):
  """Convert a TensorShapeProto to a list.

  Args:
    shape: A TensorShapeProto.

  Returns:
    List of integers representing the dimensions of the tensor.
  """
  return [dim.size for dim in shape.dim]


def _GetDenseDimensions(list_of_lists):
  """Returns the inferred dense dimensions of a list of lists.

  Only the first element at each nesting level is inspected. Anything that is
  not a list or tuple has dimensions []; an empty list/tuple has [0].
  """
  if not isinstance(list_of_lists, (list, tuple)):
    return []
  elif not list_of_lists:
    return [0]
  else:
    return [len(list_of_lists)] + _GetDenseDimensions(list_of_lists[0])


def _FlattenToStrings(nested_strings):
  """Yields the leaves of nested lists/tuples in depth-first order."""
  if isinstance(nested_strings, (list, tuple)):
    for inner in nested_strings:
      for flattened_string in _FlattenToStrings(inner):
        yield flattened_string
  else:
    yield nested_strings


# Dtypes whose values make_tensor_proto serializes as raw bytes into
# `tensor_content` (when the shape matches and there is more than 1 element).
_TENSOR_CONTENT_TYPES = frozenset([
    dtypes.float32, dtypes.float64, dtypes.int32, dtypes.uint8, dtypes.int16,
    dtypes.int8, dtypes.int64, dtypes.qint8, dtypes.quint8, dtypes.qint16,
    dtypes.quint16, dtypes.qint32, dtypes.uint32, dtypes.uint64
])


class _Message(object):
  """A value whose repr() is a fixed message, used in mismatch errors."""

  def __init__(self, message):
    self._message = message

  def __repr__(self):
    return self._message


def _FirstNotNone(l):
  """Returns the first non-None item of `l`, or None if there is none.

  A Tensor item is replaced by a _Message describing a list of Tensors.
  """
  for x in l:
    if x is not None:
      if isinstance(x, ops.Tensor):
        return _Message("list containing Tensors")
      else:
        return x
  return None


def _NotNone(v):
  """Returns `v`, or a _Message("None") placeholder when `v` is None."""
  if v is None:
    return _Message("None")
  else:
    return v


def _FilterTuple(v):
  """Checks that quantized values are given as flat tuples.

  Returns None when every innermost element of `v` is a tuple that contains no
  list/tuple; otherwise returns an offending value (a non-sequence `v` is
  returned as-is).
  """
  if not isinstance(v, (list, tuple)):
    return v
  if isinstance(v, tuple):
    if not any(isinstance(x, (list, tuple)) for x in v):
      return None
  if isinstance(v, list):
    if not any(isinstance(x, (list, tuple)) for x in v):
      return _FirstNotNone(
          [None if isinstance(x, (list, tuple)) else x for x in v])
  return _FirstNotNone([_FilterTuple(x) for x in v])


def _FilterInt(v):
  """Returns the first leaf of `v` that is not an integer/Dimension, or None."""
  if isinstance(v, (list, tuple)):
    return _FirstNotNone([_FilterInt(x) for x in v])
  return None if isinstance(
      v, (compat.integral_types, tensor_shape.Dimension)) else _NotNone(v)


def _FilterFloat(v):
  """Returns the first leaf of `v` that is not a real number, or None."""
  if isinstance(v, (list, tuple)):
    return _FirstNotNone([_FilterFloat(x) for x in v])
  return None if isinstance(v, compat.real_types) else _NotNone(v)


def _FilterComplex(v):
  """Returns the first leaf of `v` that is not a complex number, or None."""
  if isinstance(v, (list, tuple)):
    return _FirstNotNone([_FilterComplex(x) for x in v])
  return None if isinstance(v, compat.complex_types) else _NotNone(v)


def _FilterStr(v):
  """Returns the first leaf of `v` that is not bytes/text, or None."""
  if isinstance(v, (list, tuple)):
    return _FirstNotNone([_FilterStr(x) for x in v])
  if isinstance(v, compat.bytes_or_text_types):
    return None
  else:
    return _NotNone(v)


def _FilterBool(v):
  """Returns the first leaf of `v` that is not a bool, or None."""
  if isinstance(v, (list, tuple)):
    return _FirstNotNone([_FilterBool(x) for x in v])
  return None if isinstance(v, bool) else _NotNone(v)


def _FilterNotTensor(v):
  """Returns str() of the first Tensor leaf found in `v`, or None."""
  if isinstance(v, (list, tuple)):
    return _FirstNotNone([_FilterNotTensor(x) for x in v])
  return str(v) if isinstance(v, ops.Tensor) else None


# Per-dtype filters used by _AssertCompatible. Every filter in a list must
# return None for the values to be accepted.
_TF_TO_IS_OK = {
    dtypes.bool: [_FilterBool],
    dtypes.complex128: [_FilterComplex],
    dtypes.complex64: [_FilterComplex],
    dtypes.float16: [_FilterFloat],
    dtypes.float32: [_FilterFloat],
    dtypes.float64: [_FilterFloat],
    dtypes.int16: [_FilterInt],
    dtypes.int32: [_FilterInt],
    dtypes.int64: [_FilterInt],
    dtypes.int8: [_FilterInt],
    dtypes.qint16: [_FilterInt, _FilterTuple],
    dtypes.qint32: [_FilterInt, _FilterTuple],
    dtypes.qint8: [_FilterInt, _FilterTuple],
    dtypes.quint16: [_FilterInt, _FilterTuple],
    dtypes.quint8: [_FilterInt, _FilterTuple],
    dtypes.string: [_FilterStr],
    dtypes.uint16: [_FilterInt],
    dtypes.uint8: [_FilterInt],
    dtypes.uint32: [_FilterInt],
    dtypes.uint64: [_FilterInt],
}


def _AssertCompatible(values, dtype):
  """Raises TypeError if python `values` are not compatible with `dtype`.

  Args:
    values: A python scalar or (nested) list/tuple of values.
    dtype: A DType, or None (in which case only Tensors are rejected).

  Raises:
    TypeError: if a value of the wrong kind is found.
  """
  if dtype is None:
    fn_list = [_FilterNotTensor]
  else:
    try:
      fn_list = _TF_TO_IS_OK[dtype]
    except KeyError:
      # There isn't a specific fn_list, so we try to do the best possible.
      if dtype.is_integer:
        fn_list = [_FilterInt]
      elif dtype.is_floating:
        fn_list = [_FilterFloat]
      elif dtype.is_complex:
        fn_list = [_FilterComplex]
      elif dtype.is_quantized:
        fn_list = [_FilterInt, _FilterTuple]
      else:
        fn_list = [_FilterNotTensor]
  mismatch = _FirstNotNone([fn(values) for fn in fn_list])
  if mismatch is not None:
    if dtype is None:
      raise TypeError("List of Tensors when single Tensor expected")
    else:
      raise TypeError("Expected %s, got %s of type '%s' instead." %
                      (dtype.name, repr(mismatch), type(mismatch).__name__))


# pylint: disable=invalid-name
@tf_export(v1=["make_tensor_proto"])
def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False,
                      allow_broadcast=False):
  """Create a TensorProto.

  Args:
    values: Values to put in the TensorProto.
    dtype: Optional tensor_pb2 DataType value.
    shape: List of integers representing the dimensions of tensor.
    verify_shape: Boolean that enables verification of a shape of values.
    allow_broadcast: Boolean that enables allowing scalars and 1 length vector
      broadcasting. Cannot be true when verify_shape is true.

  Returns:
    A `TensorProto`. Depending on the type, it may contain data in the
    "tensor_content" attribute, which is not directly useful to Python programs.
    To access the values you should convert the proto back to a numpy ndarray
    with `tf.make_ndarray(proto)`.

    If `values` is a `TensorProto`, it is immediately returned; `dtype` and
    `shape` are ignored.

  Raises:
    TypeError: if unsupported or incompatible types are provided, or if
      verify_shape (or allow_broadcast) is True and the shape of values does
      not match `shape`.
    ValueError: if arguments have inappropriate values, e.g. both
      allow_broadcast and verify_shape are set, values is None, a nested list
      is not dense, more elements are given than `shape` holds, or the
      content would exceed 2GB.

  make_tensor_proto accepts "values" of a python scalar, a python list, a
  numpy ndarray, or a numpy scalar.

  If "values" is a python scalar or a python list, make_tensor_proto
  first converts it to a numpy ndarray. If dtype is None, the
  conversion tries its best to infer the right numpy data
  type. Otherwise, the resulting numpy array has a compatible data
  type with the given dtype.

  In either case above, the numpy ndarray (either the caller provided
  or the auto converted) must have the compatible type with dtype.

  make_tensor_proto then converts the numpy array to a tensor proto.

  If "shape" is None, the resulting tensor proto represents the numpy
  array precisely.

  Otherwise, "shape" specifies the tensor's shape and the numpy array
  can not have more elements than what "shape" specifies.
  """
  if allow_broadcast and verify_shape:
    raise ValueError("allow_broadcast and verify_shape are not both allowed.")
  if isinstance(values, tensor_pb2.TensorProto):
    return values

  if dtype:
    dtype = dtypes.as_dtype(dtype)

  is_quantized = (
      dtype in [
          dtypes.qint8, dtypes.quint8, dtypes.qint16, dtypes.quint16,
          dtypes.qint32
      ])

  # We first convert value to a numpy array or scalar.
  if isinstance(values, (np.ndarray, np.generic)):
    if dtype:
      nparray = values.astype(dtype.as_numpy_dtype)
    else:
      nparray = values
  elif callable(getattr(values, "__array__", None)) or isinstance(
      getattr(values, "__array_interface__", None), dict):
    # If a class has the __array__ method, or __array_interface__ dict, then it
    # is possible to convert to numpy array.
    nparray = np.asarray(values, dtype=dtype)

    # This is the preferred way to create an array from the object, so replace
    # the `values` with the array so that _FlattenToStrings is not run.
    values = nparray
  else:
    if values is None:
      raise ValueError("None values not supported.")
    # if dtype is provided, forces numpy array to be the type
    # provided if possible.
    if dtype and dtype.is_numpy_compatible:
      np_dt = dtype.as_numpy_dtype
    else:
      np_dt = None
    # If shape is None, numpy.prod returns None when dtype is not set, but
    # raises exception when dtype is set to np.int64
    if shape is not None and np.prod(shape, dtype=np.int64) == 0:
      nparray = np.empty(shape, dtype=np_dt)
    else:
      _AssertCompatible(values, dtype)
      nparray = np.array(values, dtype=np_dt)
      # Ragged nested lists would silently become object arrays, so check
      # the array shape against the nesting of `values`.
      # We need to pass in quantized values as tuples, so don't apply the shape
      # check to them.
      if (list(nparray.shape) != _GetDenseDimensions(values) and
          not is_quantized):
        raise ValueError("""Argument must be a dense tensor: %s"""
                         """ - got shape %s, but wanted %s.""" %
                         (values, list(nparray.shape),
                          _GetDenseDimensions(values)))

    # python/numpy default float type is float64. We prefer float32 instead.
    if (nparray.dtype == np.float64) and dtype is None:
      nparray = nparray.astype(np.float32)
    # python/numpy default int type is int64. We prefer int32 instead.
    elif (nparray.dtype == np.int64) and dtype is None:
      downcasted_array = nparray.astype(np.int32)
      # Do not down cast if it leads to precision loss.
      if np.array_equal(downcasted_array, nparray):
        nparray = downcasted_array

  # if dtype is provided, it must be compatible with what numpy
  # conversion says.
  numpy_dtype = dtypes.as_dtype(nparray.dtype)
  if numpy_dtype is None:
    raise TypeError("Unrecognized data type: %s" % nparray.dtype)

  # If dtype was specified and is a quantized type, we convert
  # numpy_dtype back into the quantized version.
  if is_quantized:
    numpy_dtype = dtype

  if dtype is not None and (not hasattr(dtype, "base_dtype") or
                            dtype.base_dtype != numpy_dtype.base_dtype):
    raise TypeError("Incompatible types: %s vs. %s. Value is %s" %
                    (dtype, nparray.dtype, values))

  # If shape is not given, get the shape from the numpy array.
  if shape is None:
    shape = nparray.shape
    is_same_size = True
    shape_size = nparray.size
  else:
    shape = [int(dim) for dim in shape]
    shape_size = np.prod(shape, dtype=np.int64)
    is_same_size = shape_size == nparray.size

    if allow_broadcast:
      if nparray.shape == (1,) or nparray.shape == tuple():
        pass
      elif nparray.size != shape_size:
        raise TypeError("Expected Tensor's shape: %s, got %s." %
                        (tuple(shape), nparray.shape))

    else:
      if verify_shape and nparray.shape != tuple(shape):
        raise TypeError("Expected Tensor's shape: %s, got %s." %
                        (tuple(shape), nparray.shape))

      if nparray.size > shape_size:
        raise ValueError(
            "Too many elements provided. Needed at most %d, but received %d" %
            (shape_size, nparray.size))

  tensor_proto = tensor_pb2.TensorProto(
      dtype=numpy_dtype.as_datatype_enum,
      tensor_shape=tensor_shape.as_shape(shape).as_proto())

  if is_same_size and numpy_dtype in _TENSOR_CONTENT_TYPES and shape_size > 1:
    if nparray.size * nparray.itemsize >= (1 << 31):
      raise ValueError(
          "Cannot create a tensor proto whose content is larger than 2GB.")
    # tobytes() is the non-deprecated spelling of tostring(); same bytes.
    tensor_proto.tensor_content = nparray.tobytes()
    return tensor_proto

  # If we were not given values as a numpy array, compute the proto_values
  # from the given values directly, to avoid numpy trimming nulls from the
  # strings. Since values could be a list of strings, or a multi-dimensional
  # list of lists that might or might not correspond to the given shape,
  # we flatten it conservatively.
  if numpy_dtype == dtypes.string and not isinstance(values, np.ndarray):
    proto_values = _FlattenToStrings(values)

    # At this point, values may be a list of objects that we could not
    # identify a common type for (hence it was inferred as
    # object dtype / dtypes.string). If we are unable to convert it to a
    # string, we raise a more helpful error message.
    #
    # Ideally, we'd be able to convert the elements of the list to a
    # common type, but this type inference requires some thinking and
    # so we defer it for now.
    try:
      str_values = [compat.as_bytes(x) for x in proto_values]
    except TypeError:
      raise TypeError("Failed to convert object of type %s to Tensor. "
                      "Contents: %s. Consider casting elements to a "
                      "supported type." % (type(values), values))
    tensor_proto.string_val.extend(str_values)
    return tensor_proto

  # TensorFlow expects C order (a.k.a., eigen row major).
  proto_values = nparray.ravel()

  append_fn = GetNumpyAppendFn(proto_values.dtype)
  if append_fn is None:
    raise TypeError(
        "Element type not supported in TensorProto: %s" % numpy_dtype.name)
  append_fn(tensor_proto, proto_values)

  return tensor_proto
# pylint: enable=invalid-name


@tf_export("make_ndarray")
def MakeNdarray(tensor):
  """Create a numpy ndarray from a tensor.

  Create a numpy ndarray with the same shape and data as the tensor. When the
  proto holds fewer values than its shape requires, the last value is
  repeated to fill the array (an empty proto yields zeros for non-string
  types).

  Args:
    tensor: A TensorProto.

  Returns:
    A numpy array with the tensor contents.

  Raises:
    TypeError: if tensor has unsupported type.

  """
  shape = [d.size for d in tensor.tensor_shape.dim]
  num_elements = np.prod(shape, dtype=np.int64)
  tensor_dtype = dtypes.as_dtype(tensor.dtype)
  dtype = tensor_dtype.as_numpy_dtype

  if tensor.tensor_content:
    return (np.frombuffer(tensor.tensor_content,
                          dtype=dtype).copy().reshape(shape))

  if tensor_dtype == dtypes.string:
    # np.pad throws on these arrays of object dtype, so pad the list manually.
    values = list(tensor.string_val)
    padding = num_elements - len(values)
    if padding > 0:
      last = values[-1] if values else ""
      values.extend([last] * padding)
    return np.array(values, dtype=dtype).reshape(shape)

  if tensor_dtype == dtypes.float16 or tensor_dtype == dtypes.bfloat16:
    # the half_val field of the TensorProto stores the binary representation
    # of the 16-bit float: reinterpret the uint16 bits as the proper type.
    # view() replaces the deprecated in-place `values.dtype = ...` assignment.
    values = np.fromiter(tensor.half_val, dtype=np.uint16)
    values = values.view(tensor_dtype.as_numpy_dtype)
  elif tensor_dtype == dtypes.float32:
    values = np.fromiter(tensor.float_val, dtype=dtype)
  elif tensor_dtype == dtypes.float64:
    values = np.fromiter(tensor.double_val, dtype=dtype)
  elif tensor_dtype in [
      dtypes.int32, dtypes.uint8, dtypes.uint16, dtypes.int16, dtypes.int8,
      dtypes.qint32, dtypes.quint8, dtypes.qint8, dtypes.qint16, dtypes.quint16
  ]:
    values = np.fromiter(tensor.int_val, dtype=dtype)
  elif tensor_dtype == dtypes.int64:
    values = np.fromiter(tensor.int64_val, dtype=dtype)
  elif tensor_dtype == dtypes.uint32:
    # make_tensor_proto writes non-tensor_content uint32 values here.
    values = np.fromiter(tensor.uint32_val, dtype=dtype)
  elif tensor_dtype == dtypes.uint64:
    # make_tensor_proto writes non-tensor_content uint64 values here.
    values = np.fromiter(tensor.uint64_val, dtype=dtype)
  elif tensor_dtype == dtypes.complex64:
    it = iter(tensor.scomplex_val)
    values = np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype)
  elif tensor_dtype == dtypes.complex128:
    it = iter(tensor.dcomplex_val)
    values = np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype)
  elif tensor_dtype == dtypes.bool:
    values = np.fromiter(tensor.bool_val, dtype=dtype)
  else:
    raise TypeError("Unsupported tensor type: %s" % tensor.dtype)

  if values.size == 0:
    return np.zeros(shape, dtype)

  if values.size != num_elements:
    values = np.pad(values, (0, num_elements - values.size), "edge")

  return values.reshape(shape)


def ShapeEquals(tensor_proto, shape):
  """Returns True if "tensor_proto" has the given "shape".

  Args:
    tensor_proto: A TensorProto.
    shape: A tensor shape, expressed as a TensorShapeProto, list, or tuple.

  Returns:
    True if "tensor_proto" has the given "shape" (same rank and the same size
    in every dimension), otherwise False.

  Raises:
    TypeError: If "tensor_proto" is not a TensorProto, or shape is not a
      TensorShapeProto, list, or tuple.
  """
  if not isinstance(tensor_proto, tensor_pb2.TensorProto):
    raise TypeError("tensor_proto is not a tensor_pb2.TensorProto object")
  if isinstance(shape, tensor_shape_pb2.TensorShapeProto):
    shape = [d.size for d in shape.dim]
  elif not isinstance(shape, (list, tuple)):
    raise TypeError("shape is not a list or tuple")
  tensor_shape_list = [d.size for d in tensor_proto.tensor_shape.dim]
  # zip() stops at the shorter sequence, so ranks must be compared first or
  # e.g. [2, 3] would wrongly "equal" [2].
  if len(tensor_shape_list) != len(shape):
    return False
  return all(x == y for x, y in zip(tensor_shape_list, shape))


def _ConstantValue(tensor, partial):
  """Helper for constant_value: folds `tensor` by inspecting its producer op.

  Args:
    tensor: An ops.Tensor.
    partial: Forwarded to constant_value for the inputs of "Pack" ops.

  Returns:
    A numpy value, or None if the value cannot be determined statically.

  Raises:
    TypeError: if `tensor` is not an ops.Tensor.
  """
  # TODO(touts): Support Variables?
  if not isinstance(tensor, ops.Tensor):
    raise TypeError("%r is not a Tensor, has type %s" % (tensor, type(tensor)))
  if tensor.op.type == "Const":
    return MakeNdarray(tensor.op.get_attr("value"))
  elif tensor.op.type == "Shape":
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.is_fully_defined():
      return np.array(
          [dim.value for dim in input_shape.dims],
          dtype=tensor.dtype.as_numpy_dtype)
    else:
      return None
  elif tensor.op.type == "Size":
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.is_fully_defined():
      # Use the op's own output dtype, as the "Shape" branch does, instead of
      # hard-coding int32 (which could overflow for a wider output type).
      return np.prod([dim.value for dim in input_shape.dims],
                     dtype=tensor.dtype.as_numpy_dtype)
    else:
      return None
  elif tensor.op.type == "Rank":
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.ndims is not None:
      return np.ndarray(
          shape=(),
          buffer=np.array([input_shape.ndims], dtype=np.int32),
          dtype=np.int32)
    else:
      return None
  elif tensor.op.type == "Range":
    start = constant_value(tensor.op.inputs[0])
    if start is None:
      return None
    limit = constant_value(tensor.op.inputs[1])
    if limit is None:
      return None
    delta = constant_value(tensor.op.inputs[2])
    if delta is None:
      return None
    return np.arange(start, limit, delta, dtype=tensor.dtype.as_numpy_dtype)
  elif tensor.op.type == "Cast":
    pre_cast = constant_value(tensor.op.inputs[0])
    if pre_cast is None:
      return None
    cast_dtype = dtypes.as_dtype(tensor.op.get_attr("DstT"))
    return pre_cast.astype(cast_dtype.as_numpy_dtype)
  elif tensor.op.type == "Concat":
    # Concat takes the axis as its first input.
    dim = constant_value(tensor.op.inputs[0])
    if dim is None:
      return None
    values = []
    for x in tensor.op.inputs[1:]:
      value = constant_value(x)
      if value is None:
        return None
      values.append(value)
    return np.concatenate(values, axis=dim)
  elif tensor.op.type == "ConcatV2":
    # ConcatV2 takes the axis as its last input.
    dim = constant_value(tensor.op.inputs[-1])
    if dim is None:
      return None
    values = []
    for x in tensor.op.inputs[:-1]:
      value = constant_value(x)
      if value is None:
        return None
      values.append(value)
    return np.concatenate(values, axis=dim)
  elif tensor.op.type == "Pack":
    values = []
    # Some imported GraphDefs have Pack ops with zero inputs. Those are invalid
    # and shouldn't be produced, but to deal sensibly with them here we check
    # and return None.
    if not tensor.op.inputs:
      return None
    # We can't handle axis != 0 Packs at the moment.
    if tensor.op.get_attr("axis") != 0:
      return None
    for x in tensor.op.inputs:
      value = constant_value(x, partial)
      if value is None and not partial:
        return None
      values.append(value)
    return np.array(values)
  elif tensor.op.type == "Fill":
    fill_shape = tensor.shape
    fill_value = constant_value(tensor.op.inputs[1])
    if fill_shape.is_fully_defined() and fill_value is not None:
      return np.full(fill_shape.as_list(), fill_value, dtype=fill_value.dtype)
    else:
      return None
  elif tensor.op.type == "Equal":
    value1 = constant_value(tensor.op.inputs[0])
    if value1 is None:
      return None
    value2 = constant_value(tensor.op.inputs[1])
    if value2 is None:
      return None
    return np.equal(value1, value2)
  elif tensor.op.type == "NotEqual":
    value1 = constant_value(tensor.op.inputs[0])
    if value1 is None:
      return None
    value2 = constant_value(tensor.op.inputs[1])
    if value2 is None:
      return None
    return np.not_equal(value1, value2)
  else:
    return None


@tf_export('get_static_value')
def constant_value(tensor, partial=False):  # pylint: disable=invalid-name
  """Returns the constant value of the given tensor, if efficiently calculable.

  This function attempts to partially evaluate the given tensor, and
  returns its value as a numpy ndarray if this succeeds.

  Compatibility(V1): If `constant_value(tensor)` returns a non-`None` result, it
  will no longer be possible to feed a different value for `tensor`. This allows
  the result of this function to influence the graph that is constructed, and
  permits static shape optimizations.

  Args:
    tensor: The Tensor to be evaluated.
    partial: If True, the returned numpy array is allowed to have partially
      evaluated values. Values that can't be evaluated will be None.

  Returns:
    A numpy ndarray containing the constant value of the given `tensor`,
    or None if it cannot be calculated. An EagerTensor yields its `.numpy()`
    value; an object that is not tensor-like (see `is_tensor`) is returned
    unchanged; a tensor-like object that is not an `ops.Tensor` yields None.
  """
  if isinstance(tensor, ops.EagerTensor):
    return tensor.numpy()
  if not is_tensor(tensor):
    return tensor
  if not isinstance(tensor, ops.Tensor):
    return None
  ret = _ConstantValue(tensor, partial)
  if ret is not None:
    # The caller may now depend on the constant value of `tensor`, so we
    # conservatively prevent it from being fed.
    tensor.graph.prevent_feeding(tensor)
  return ret


def constant_value_as_shape(tensor):  # pylint: disable=invalid-name
  """A version of `constant_value()` that returns a `TensorShape`.

  This version should be used when a constant tensor value is
  interpreted as a (possibly partial) shape, e.g. in the shape
  function for `tf.reshape()`. By explicitly requesting a
  `TensorShape` as the return value, it is possible to represent
  unknown dimensions; by contrast, `constant_value()` is
  all-or-nothing.

  Args:
    tensor: The rank-0 or rank-1 Tensor to be evaluated.

  Returns:
    A `TensorShape` based on the constant value of the given `tensor`.

  Raises:
    ValueError: If the shape is rank-0 and is not statically known to be -1.
  """
  if isinstance(tensor, ops.EagerTensor):
    # -1 entries denote unknown dimensions.
    return tensor_shape.as_shape(
        [dim if dim != -1 else None for dim in tensor.numpy()])

  if tensor.get_shape().ndims == 0:
    value = constant_value(tensor)
    if value is None:
      raise ValueError(
          "Received a scalar with unknown value as shape; require a statically "
          "known scalar with value '-1' to describe an unknown shape.")
    if value != -1:
      raise ValueError(
          "Received a scalar value '%s' as shape; require a statically known "
          "scalar with value '-1' to describe an unknown shape." % value)
    return tensor_shape.unknown_shape()

  shape = tensor.get_shape().with_rank(1)
  if shape == [0]:
    return tensor_shape.scalar()
  elif tensor.op.type == "Shape":
    return tensor.op.inputs[0].get_shape()
  elif tensor.op.type == "Pack":
    ret = tensor_shape.scalar()  # Empty list.
    # Since we expect rank 1 inputs, Pack's axis must be zero, otherwise it
    # would not be rank 1.
    assert tensor.op.get_attr("axis") == 0
    for pack_input in tensor.op.inputs:
      # `pack_input` must be a scalar. Attempt to evaluate it, and append it
      # to `ret`.
      pack_input_val = constant_value(pack_input)
      if pack_input_val is None or pack_input_val < 0:
        new_dim = tensor_shape.Dimension(None)
      else:
        new_dim = tensor_shape.Dimension(pack_input_val)
      ret = ret.concatenate([new_dim])
    return ret
  elif tensor.op.type == "Concat":
    # We assume that `tensor.op.inputs[0]` evaluates to 0, as this is
    # the only legal value when concatenating vectors, and it will
    # have been checked by a previous shape function.
    ret = tensor_shape.scalar()  # Empty list.
    for concat_input in tensor.op.inputs[1:]:
      # `concat_input` must be a vector. Attempt to evaluate it as a shape,
      # and concatenate it with `ret`.
      ret = ret.concatenate(constant_value_as_shape(concat_input))
    return ret
  elif tensor.op.type == "ConcatV2":
    # We assume that `tensor.op.inputs[-1]` evaluates to 0, as this is
    # the only legal value when concatenating vectors, and it will
    # have been checked by a previous shape function.
    ret = tensor_shape.scalar()  # Empty list.
    for concat_input in tensor.op.inputs[:-1]:
      # `concat_input` must be a vector. Attempt to evaluate it as a shape,
      # and concatenate it with `ret`.
      ret = ret.concatenate(constant_value_as_shape(concat_input))
    return ret
  elif tensor.op.type == "StridedSlice":
    try:
      begin = constant_value(tensor.op.inputs[1])
      end = constant_value(tensor.op.inputs[2])
      strides = constant_value(tensor.op.inputs[3])
      if begin is not None and end is not None and strides is not None:
        begin = begin[0]
        end = end[0]
        strides = strides[0]
        begin_mask = tensor.op.get_attr("begin_mask")
        if begin_mask == 1:
          begin = None
        end_mask = tensor.op.get_attr("end_mask")
        if end_mask == 1:
          end = None

        ellipsis_mask = tensor.op.get_attr("ellipsis_mask")
        new_axis_mask = tensor.op.get_attr("new_axis_mask")
        shrink_axis_mask = tensor.op.get_attr("shrink_axis_mask")
        # Only a plain 1-D slice (optionally open-ended) can be mirrored by
        # slicing the input's TensorShape.
        valid_attributes = (not ellipsis_mask and not new_axis_mask and
                            not shrink_axis_mask and (not begin_mask or
                                                      (begin_mask == 1)) and
                            (not end_mask or (end_mask == 1)))
        if valid_attributes:  # additional inputs not supported
          prev = constant_value_as_shape(tensor.op.inputs[0])
          prev = prev[begin:end:strides]
          ret = tensor_shape.TensorShape(prev)
          return ret

    except ValueError:  # Could come from get_attr or slicing prev.
      pass
    except TypeError:  # Could come from slicing prev.
      pass

  # Fallback: a shape of the right length whose dims are filled in from the
  # constant value where known (negative entries mean unknown).
  ret = tensor_shape.unknown_shape(shape.dims[0].value)
  value = constant_value(tensor)
  if value is not None:
    ret = ret.merge_with(
        tensor_shape.TensorShape([d if d >= 0 else None for d in value]))
  return ret


@tf_export("is_tensor")
def is_tensor(x):  # pylint: disable=invalid-name
  """Check whether `x` is of tensor type.

  Check whether an object is a tensor or a composite tensor. This check is
  equivalent to calling
  `isinstance(x, (tf.Tensor, tf.SparseTensor, tf.RaggedTensor, tf.Variable))`
  and also checks if all the component variables of a MirroredVariable or a
  SyncOnReadVariable are tensors.

  Args:
    x: A python object to check.

  Returns:
    `True` if `x` is a tensor, `False` if not.
  """
  return (isinstance(x, ops._TensorLike) or ops.is_dense_tensor_like(x) or  # pylint: disable=protected-access
          isinstance(x, composite_tensor.CompositeTensor) or
          (hasattr(x, "is_tensor_like") and x.is_tensor_like))