• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities to create TensorProtos."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21import six
22
23from tensorflow.core.framework import tensor_pb2
24from tensorflow.core.framework import tensor_shape_pb2
25from tensorflow.python.eager import context
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import errors_impl
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.types import core
31from tensorflow.python.types import internal
32from tensorflow.python.util import compat
33from tensorflow.python.util import nest
34from tensorflow.python.util.tf_export import tf_export
35
36# Fallback in case fast_tensor_util is not properly compiled.
37# pylint: disable=g-import-not-at-top
38try:
39  from tensorflow.python.framework import fast_tensor_util
40  _FAST_TENSOR_UTIL_AVAILABLE = True
41except ImportError:
42  _FAST_TENSOR_UTIL_AVAILABLE = False
43# pylint: enable=g-import-not-at-top
44
45
def ExtractBitsFromFloat16(x):
  """Returns the raw float16 bit pattern of `x` as a Python int.

  Args:
    x: A value convertible to `np.float16`.

  Returns:
    The 16 bits of the float16 encoding of `x`, as an unsigned Python int.
  """
  as_half = np.asarray(x, dtype=np.float16)
  # Reinterpret (not convert) the two bytes as an unsigned 16-bit integer.
  bit_pattern = as_half.view(np.uint16)
  return bit_pattern.item()
48
49
def SlowAppendFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Pure-Python fallback that appends float16 values to `half_val`.

  Each element of `proto_values` is converted to its 16-bit pattern with
  ExtractBitsFromFloat16 before being stored.
  """
  half_bits = list(map(ExtractBitsFromFloat16, proto_values))
  tensor_proto.half_val.extend(half_bits)
53
54
def _MediumAppendFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Appends float16 values via fast_tensor_util using their uint16 bits.

  The compiled helper receives a uint16 view of the float16 data rather than
  the float16 array itself.
  """
  # TODO: Remove the conversion if cython supports np.float16_t
  half_bits = np.asarray(proto_values, dtype=np.float16).view(np.uint16)
  fast_tensor_util.AppendFloat16ArrayToTensorProto(tensor_proto, half_bits)
60
61
def ExtractBitsFromBFloat16(x):
  """Returns the raw bfloat16 bit pattern of `x` as a Python int."""
  bf16_type = dtypes.bfloat16.as_numpy_dtype
  as_bf16 = np.asarray(x, dtype=bf16_type)
  # Reinterpret the two bytes of the bfloat16 value as an unsigned integer.
  return as_bf16.view(np.uint16).item()
65
66
def SlowAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Pure-Python fallback that appends bfloat16 values to `half_val`.

  bfloat16 shares the `half_val` field with float16; each element is stored
  as its 16-bit pattern from ExtractBitsFromBFloat16.
  """
  bf16_bits = list(map(ExtractBitsFromBFloat16, proto_values))
  tensor_proto.half_val.extend(bf16_bits)
70
71
def FastAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Appends bfloat16 values via fast_tensor_util using their uint16 bits."""
  bf16_values = np.asarray(proto_values, dtype=dtypes.bfloat16.as_numpy_dtype)
  fast_tensor_util.AppendBFloat16ArrayToTensorProto(
      tensor_proto, bf16_values.view(np.uint16))
76
77
if _FAST_TENSOR_UTIL_AVAILABLE:
  # Append functions backed by the compiled fast_tensor_util module, keyed by
  # numpy scalar type. Keys are matched with `==` (see GetFromNumpyDTypeDict),
  # so an `np.dtype` finds the entry for its scalar type. Quantized types reuse
  # the function for the integer type of the same width and signedness.
  _NP_TO_APPEND_FN = {
      dtypes.bfloat16.as_numpy_dtype:
          FastAppendBFloat16ArrayToTensorProto,
      np.float16:
          _MediumAppendFloat16ArrayToTensorProto,
      np.float32:
          fast_tensor_util.AppendFloat32ArrayToTensorProto,
      np.float64:
          fast_tensor_util.AppendFloat64ArrayToTensorProto,
      np.int32:
          fast_tensor_util.AppendInt32ArrayToTensorProto,
      np.int64:
          fast_tensor_util.AppendInt64ArrayToTensorProto,
      np.uint8:
          fast_tensor_util.AppendUInt8ArrayToTensorProto,
      np.uint16:
          fast_tensor_util.AppendUInt16ArrayToTensorProto,
      np.uint32:
          fast_tensor_util.AppendUInt32ArrayToTensorProto,
      np.uint64:
          fast_tensor_util.AppendUInt64ArrayToTensorProto,
      np.int8:
          fast_tensor_util.AppendInt8ArrayToTensorProto,
      np.int16:
          fast_tensor_util.AppendInt16ArrayToTensorProto,
      np.complex64:
          fast_tensor_util.AppendComplex64ArrayToTensorProto,
      np.complex128:
          fast_tensor_util.AppendComplex128ArrayToTensorProto,
      # Use np.object_ / np.bool_: the np.object / np.bool aliases were
      # removed in NumPy 1.24. Both spellings compare equal to the same dtype.
      np.object_:
          fast_tensor_util.AppendObjectArrayToTensorProto,
      np.bool_:
          fast_tensor_util.AppendBoolArrayToTensorProto,
      dtypes.qint8.as_numpy_dtype:
          fast_tensor_util.AppendInt8ArrayToTensorProto,
      dtypes.quint8.as_numpy_dtype:
          fast_tensor_util.AppendUInt8ArrayToTensorProto,
      dtypes.qint16.as_numpy_dtype:
          fast_tensor_util.AppendInt16ArrayToTensorProto,
      dtypes.quint16.as_numpy_dtype:
          fast_tensor_util.AppendUInt16ArrayToTensorProto,
      dtypes.qint32.as_numpy_dtype:
          fast_tensor_util.AppendInt32ArrayToTensorProto,
  }
else:
  # Pure-Python fallbacks used when fast_tensor_util could not be imported.
  # Each converts numpy elements to Python scalars with `.item()` and extends
  # the matching repeated field of the TensorProto.

  def SlowAppendFloat32ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.float_val.extend([x.item() for x in proto_values])

  def SlowAppendFloat64ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.double_val.extend([x.item() for x in proto_values])

  def SlowAppendIntArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.int_val.extend([x.item() for x in proto_values])

  def SlowAppendInt64ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.int64_val.extend([x.item() for x in proto_values])

  def SlowAppendQIntArrayToTensorProto(tensor_proto, proto_values):
    # Presumably quantized numpy scalars (structured dtypes) yield a 1-tuple
    # from `.item()`; the integer is its first field. TODO confirm.
    tensor_proto.int_val.extend([x.item()[0] for x in proto_values])

  def SlowAppendUInt32ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.uint32_val.extend([x.item() for x in proto_values])

  def SlowAppendUInt64ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.uint64_val.extend([x.item() for x in proto_values])

  def SlowAppendComplex64ArrayToTensorProto(tensor_proto, proto_values):
    # Complex values are stored as interleaved (real, imag) pairs.
    tensor_proto.scomplex_val.extend(
        [v.item() for x in proto_values for v in [x.real, x.imag]])

  def SlowAppendComplex128ArrayToTensorProto(tensor_proto, proto_values):
    # Complex values are stored as interleaved (real, imag) pairs.
    tensor_proto.dcomplex_val.extend(
        [v.item() for x in proto_values for v in [x.real, x.imag]])

  def SlowAppendObjectArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.string_val.extend([compat.as_bytes(x) for x in proto_values])

  def SlowAppendBoolArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.bool_val.extend([x.item() for x in proto_values])

  _NP_TO_APPEND_FN = {
      dtypes.bfloat16.as_numpy_dtype: SlowAppendBFloat16ArrayToTensorProto,
      np.float16: SlowAppendFloat16ArrayToTensorProto,
      np.float32: SlowAppendFloat32ArrayToTensorProto,
      np.float64: SlowAppendFloat64ArrayToTensorProto,
      np.int32: SlowAppendIntArrayToTensorProto,
      np.int64: SlowAppendInt64ArrayToTensorProto,
      np.uint8: SlowAppendIntArrayToTensorProto,
      np.uint16: SlowAppendIntArrayToTensorProto,
      np.uint32: SlowAppendUInt32ArrayToTensorProto,
      np.uint64: SlowAppendUInt64ArrayToTensorProto,
      np.int8: SlowAppendIntArrayToTensorProto,
      np.int16: SlowAppendIntArrayToTensorProto,
      np.complex64: SlowAppendComplex64ArrayToTensorProto,
      np.complex128: SlowAppendComplex128ArrayToTensorProto,
      # np.object_ / np.bool_: the np.object / np.bool aliases were removed
      # in NumPy 1.24.
      np.object_: SlowAppendObjectArrayToTensorProto,
      np.bool_: SlowAppendBoolArrayToTensorProto,
      dtypes.qint8.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
      dtypes.quint8.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
      dtypes.qint16.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
      dtypes.quint16.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
      dtypes.qint32.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
  }
185
186
def GetFromNumpyDTypeDict(dtype_dict, dtype):
  """Looks up `dtype` in a dict keyed by numpy scalar types.

  Args:
    dtype_dict: Dict whose keys are numpy scalar types or dtypes.
    dtype: The `np.dtype` to look up.

  Returns:
    The value of the first key (in iteration order) that compares equal to
    `dtype`, or None if no key does.
  """
  # NOTE: dtype_dict.get(dtype) always returns None -- presumably because an
  # np.dtype does not hash like the scalar type it compares equal to -- so
  # compare every key with `==` instead.
  for key, val in dtype_dict.items():
    if key == dtype:
      return val
  return None
193
194
def GetNumpyAppendFn(dtype):
  """Returns the function that appends flat values of `dtype` to a TensorProto.

  Args:
    dtype: A numpy dtype.

  Returns:
    A callable `fn(tensor_proto, proto_values)`, or None if no append function
    is registered for `dtype`.
  """
  # NumPy string dtypes are variable length (e.g. `S5`, `<U3`), so they cannot
  # be matched against a single dict key; check the scalar type instead.
  # `np.bytes_` / `np.str_` are the canonical names that exist in every NumPy
  # release; the `np.string_` / `np.unicode_` aliases were removed in NumPy 2.0.
  if dtype.type in (np.bytes_, np.str_):
    if _FAST_TENSOR_UTIL_AVAILABLE:
      return fast_tensor_util.AppendObjectArrayToTensorProto
    return SlowAppendObjectArrayToTensorProto
  return GetFromNumpyDTypeDict(_NP_TO_APPEND_FN, dtype)
206
207
def TensorShapeProtoToList(shape):
  """Converts a TensorShapeProto to a list of dimension sizes.

  Args:
    shape: A TensorShapeProto.

  Returns:
    A list holding the `size` of each entry of `shape.dim`, in order.
  """
  sizes = []
  for dimension in shape.dim:
    sizes.append(dimension.size)
  return sizes
218
219
220def _GetDenseDimensions(list_of_lists):
221  """Returns the inferred dense dimensions of a list of lists."""
222  if not isinstance(list_of_lists, (list, tuple)):
223    return []
224  elif not list_of_lists:
225    return [0]
226  else:
227    return [len(list_of_lists)] + _GetDenseDimensions(list_of_lists[0])
228
229
230def _FlattenToStrings(nested_strings):
231  if isinstance(nested_strings, (list, tuple)):
232    for inner in nested_strings:
233      for flattened_string in _FlattenToStrings(inner):
234        yield flattened_string
235  else:
236    yield nested_strings
237
238
# Dtypes that make_tensor_proto serializes as raw bytes in `tensor_content`
# (when the array fills the shape and has more than one element) rather than
# through the typed repeated *_val fields.
# NOTE(review): uint16 is not listed, so it always uses int_val; confirm intent.
_TENSOR_CONTENT_TYPES = frozenset((
    # Floating point.
    dtypes.float16, dtypes.float32, dtypes.float64,
    # Signed integers.
    dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
    # Unsigned integers.
    dtypes.uint8, dtypes.uint32, dtypes.uint64,
    # Quantized.
    dtypes.qint8, dtypes.quint8, dtypes.qint16, dtypes.quint16, dtypes.qint32,
))
244
245
246# pylint: disable=invalid-name
247def _check_failed(v):
248  # NB. none of the _check_* functions could raise a ValueError, so
249  # it is safe to use here.
250  raise ValueError(v)
251
252
def _check_quantized(values):
  """Validates quantized input: (nested) lists whose leaves are int tuples.

  Calls `_check_failed` (which raises ValueError) on the first value that is
  neither a list nor a tuple, and `_check_int` on each element of each tuple.
  """
  # `nest` is not usable here: it would descend into the tuples, which are
  # themselves the leaves.
  if isinstance(values, tuple):
    for element in values:
      _check_int(element)
  elif isinstance(values, list):
    for child in values:
      _check_quantized(child)
  else:
    _check_failed(values)
261
262
def _generate_isinstance_check(expected_types):
  """Builds a validator that accepts only leaves of `expected_types`.

  Args:
    expected_types: A type or (possibly nested) tuple of types, as accepted
      by `isinstance`.

  Returns:
    A function `inner(values)`. It walks `nest.flatten(values)` and calls
    `_check_failed` on the first leaf that matches neither of these:
      * an instance of `expected_types`;
      * an `np.ndarray` whose dtype's scalar type subclasses `expected_types`.
  """
  def inner(values):
    for leaf in nest.flatten(values):
      if isinstance(leaf, expected_types):
        continue
      if (isinstance(leaf, np.ndarray) and
          issubclass(leaf.dtype.type, expected_types)):
        continue
      _check_failed(leaf)

  return inner
272
# Per-kind leaf validators used by _AssertCompatible. Each calls _check_failed
# on the first leaf that is not an instance of the listed Python types (or an
# ndarray whose scalar type subclasses them).
_check_bool = _generate_isinstance_check(bool)
_check_str = _generate_isinstance_check(compat.bytes_or_text_types)
_check_complex = _generate_isinstance_check(compat.complex_types)
_check_float = _generate_isinstance_check(compat.real_types)
_check_int = _generate_isinstance_check(
    (compat.integral_types, tensor_shape.Dimension))
279
280
def _check_not_tensor(values):
  """Calls `_check_failed` on the first `ops.Tensor` leaf of `values`."""
  for leaf in nest.flatten(values):
    if isinstance(leaf, ops.Tensor):
      _check_failed(leaf)
# pylint: enable=invalid-name
285
# Maps a TF dtype to the validator _AssertCompatible runs on raw Python values.
# Dtypes absent here fall back to a validator chosen from the dtype's
# is_integer / is_floating / is_complex / is_quantized flags.
_TF_TO_IS_OK = {
    # Booleans and strings.
    dtypes.bool: _check_bool,
    dtypes.string: _check_str,
    # Floating point and complex.
    dtypes.float16: _check_float,
    dtypes.float32: _check_float,
    dtypes.float64: _check_float,
    dtypes.complex64: _check_complex,
    dtypes.complex128: _check_complex,
    # Signed and unsigned integers.
    dtypes.int8: _check_int,
    dtypes.int16: _check_int,
    dtypes.int32: _check_int,
    dtypes.int64: _check_int,
    dtypes.uint8: _check_int,
    dtypes.uint16: _check_int,
    dtypes.uint32: _check_int,
    dtypes.uint64: _check_int,
    # Quantized values arrive as (nested lists of) tuples.
    dtypes.qint8: _check_quantized,
    dtypes.quint8: _check_quantized,
    dtypes.qint16: _check_quantized,
    dtypes.quint16: _check_quantized,
    dtypes.qint32: _check_quantized,
}
308
309
def _AssertCompatible(values, dtype):
  """Raises TypeError if the raw `values` do not suit `dtype`.

  Args:
    values: The (possibly nested) Python values given to make_tensor_proto.
    dtype: A `DType`, or None. With None, the only requirement is that
      `values` contain no `ops.Tensor`.

  Raises:
    TypeError: If a leaf of `values` fails the validator chosen for `dtype`.
  """
  if dtype is None:
    checker = _check_not_tensor
  else:
    checker = _TF_TO_IS_OK.get(dtype)
    if checker is None:
      # No dedicated validator: pick one from the dtype's category flags.
      if dtype.is_integer:
        checker = _check_int
      elif dtype.is_floating:
        checker = _check_float
      elif dtype.is_complex:
        checker = _check_complex
      elif dtype.is_quantized:
        checker = _check_quantized
      else:
        checker = _check_not_tensor

  try:
    checker(values)
  except ValueError as e:
    # Validators report the offending leaf as the exception's sole argument.
    (mismatch,) = e.args
    if dtype is None:
      raise TypeError("Expected any non-tensor type, got a tensor instead.")
    raise TypeError("Expected %s, got %s of type '%s' instead." %
                    (dtype.name, repr(mismatch), type(mismatch).__name__))
338
339
def _is_array_like(obj):  # pylint: disable=invalid-name
  """Returns True if `obj` should be converted with `np.asarray`.

  `obj` counts as array-like if any of these hold:
    * it has a callable `__array__`;
    * it has a dict `__array_interface__`;
    * it supports the buffer protocol and is not `bytes`.
  Graph (non-eager) tensors are never array-like.
  """
  # Graph tensors implement __array__ only to tell the user they are not a
  # valid array, so reject them before probing for __array__.
  is_graph_tensor = (
      isinstance(obj, ops.Tensor) and
      not isinstance(obj, ops._EagerTensorBase))  # pylint: disable=protected-access
  if is_graph_tensor:
    return False

  # TODO(slebedev): an object could also implement C-level array interface.
  if callable(getattr(obj, "__array__", None)):
    return True
  if isinstance(getattr(obj, "__array_interface__", None), dict):
    return True

  # Fall back to the buffer protocol. `bytes` supports it but is deliberately
  # excluded.
  try:
    memoryview(obj)
  except TypeError:
    return False
  return not isinstance(obj, bytes)
358
359
# pylint: disable=invalid-name
@tf_export("make_tensor_proto")
def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False,
                      allow_broadcast=False):
  """Create a TensorProto.

  In TensorFlow 2.0, representing tensors as protos should no longer be a
  common workflow. That said, this utility function is still useful for
  generating TF Serving request protos:

  ```python
    request = tensorflow_serving.apis.predict_pb2.PredictRequest()
    request.model_spec.name = "my_model"
    request.model_spec.signature_name = "serving_default"
    request.inputs["images"].CopyFrom(tf.make_tensor_proto(X_new))
  ```

  `make_tensor_proto` accepts "values" of a python scalar, a python list, a
  numpy ndarray, a numpy scalar, or any other array-like object (one exposing
  `__array__`, `__array_interface__`, or the buffer protocol, except `bytes`).

  If "values" is a python scalar or a python list, make_tensor_proto
  first converts it to a numpy ndarray. If dtype is None, the
  conversion tries its best to infer the right numpy data
  type. Otherwise, the resulting numpy array has a compatible data
  type with the given dtype.

  In either case above, the numpy ndarray (either the caller provided
  or the auto-converted) must have a type compatible with dtype.

  `make_tensor_proto` then converts the numpy array to a tensor proto.

  If "shape" is None, the resulting tensor proto represents the numpy
  array precisely.

  Otherwise, "shape" specifies the tensor's shape and the numpy array
  can not have more elements than what "shape" specifies.

  Args:
    values:         Values to put in the TensorProto.
    dtype:          Optional tensor_pb2 DataType value.
    shape:          List of integers representing the dimensions of tensor.
    verify_shape:   Boolean that enables verification of a shape of values.
    allow_broadcast:  Boolean that enables allowing scalars and 1 length vector
        broadcasting. Cannot be true when verify_shape is true.

  Returns:
    A `TensorProto`. Depending on the type, it may contain data in the
    "tensor_content" attribute, which is not directly useful to Python programs.
    To access the values you should convert the proto back to a numpy ndarray
    with `tf.make_ndarray(proto)`.

    If `values` is a `TensorProto`, it is immediately returned; `dtype` and
    `shape` are ignored.

  Raises:
    TypeError:  if unsupported or incompatible types are provided; if
      verify_shape is True and the shape of values differs from `shape`; or
      if allow_broadcast is True and values is neither a scalar, a length-1
      vector, nor the same size as `shape`.
    ValueError: if arguments have inappropriate values. Examples: values is
      None, a nested list is not dense, more elements are provided than
      `shape` allows, both allow_broadcast and verify_shape are True, or the
      serialized content would be 2GB or larger.

  """
  if allow_broadcast and verify_shape:
    raise ValueError("allow_broadcast and verify_shape are not both allowed.")
  if isinstance(values, tensor_pb2.TensorProto):
    return values

  if dtype:
    dtype = dtypes.as_dtype(dtype)

  is_quantized = (
      dtype in [
          dtypes.qint8, dtypes.quint8, dtypes.qint16, dtypes.quint16,
          dtypes.qint32
      ])

  # Normalize array-like inputs (see _is_array_like) to ndarrays up front.
  if _is_array_like(values):
    values = np.asarray(values)

  # We first convert value to a numpy array or scalar.
  if isinstance(values, (np.ndarray, np.generic)):
    if dtype and dtype.is_numpy_compatible:
      nparray = values.astype(dtype.as_numpy_dtype)
    else:
      nparray = values
  else:
    if values is None:
      raise ValueError("None values not supported.")
    # if dtype is provided, forces numpy array to be the type
    # provided if possible.
    if dtype and dtype.is_numpy_compatible:
      np_dt = dtype.as_numpy_dtype
    else:
      np_dt = None
    # If shape is None, numpy.prod returns None when dtype is not set, but
    # raises exception when dtype is set to np.int64
    if shape is not None and np.prod(shape, dtype=np.int64) == 0:
      nparray = np.empty(shape, dtype=np_dt)
    else:
      _AssertCompatible(values, dtype)
      nparray = np.array(values, dtype=np_dt)
      # We need to pass in quantized values as tuples, so don't apply the shape
      # check to them.
      if (list(nparray.shape) != _GetDenseDimensions(values) and
          not is_quantized):
        raise ValueError("""Argument must be a dense tensor: %s"""
                         """ - got shape %s, but wanted %s.""" %
                         (values, list(nparray.shape),
                          _GetDenseDimensions(values)))

    # python/numpy default float type is float64. We prefer float32 instead.
    if (nparray.dtype == np.float64) and dtype is None:
      nparray = nparray.astype(np.float32)
    # python/numpy default int type is int64. We prefer int32 instead.
    elif (nparray.dtype == np.int64) and dtype is None:
      downcasted_array = nparray.astype(np.int32)
      # Do not down cast if it leads to precision loss.
      if np.array_equal(downcasted_array, nparray):
        nparray = downcasted_array

  # if dtype is provided, it must be compatible with what numpy
  # conversion says.
  numpy_dtype = dtypes.as_dtype(nparray.dtype)
  if numpy_dtype is None:
    raise TypeError("Unrecognized data type: %s" % nparray.dtype)

  # If dtype was specified and is a quantized type, we convert
  # numpy_dtype back into the quantized version.
  if is_quantized:
    numpy_dtype = dtype

  if dtype is not None and (not hasattr(dtype, "base_dtype") or
                            dtype.base_dtype != numpy_dtype.base_dtype):
    raise TypeError("Incompatible types: %s vs. %s. Value is %s" %
                    (dtype, nparray.dtype, values))

  # If shape is not given, get the shape from the numpy array.
  if shape is None:
    shape = nparray.shape
    is_same_size = True
    shape_size = nparray.size
  else:
    shape = [int(dim) for dim in shape]
    shape_size = np.prod(shape, dtype=np.int64)
    is_same_size = shape_size == nparray.size

    if allow_broadcast:
      # A scalar or length-1 vector may be broadcast to any shape.
      if nparray.shape == (1,) or nparray.shape == tuple():
        pass
      elif nparray.size != shape_size:
        raise TypeError("Expected Tensor's shape: %s, got %s." %
                        (tuple(shape), nparray.shape))

    else:
      if verify_shape and nparray.shape != tuple(shape):
        raise TypeError("Expected Tensor's shape: %s, got %s." %
                        (tuple(shape), nparray.shape))

      if nparray.size > shape_size:
        raise ValueError(
            "Too many elements provided. Needed at most %d, but received %d" %
            (shape_size, nparray.size))

  tensor_proto = tensor_pb2.TensorProto(
      dtype=numpy_dtype.as_datatype_enum,
      tensor_shape=tensor_shape.as_shape(shape).as_proto())

  # When the values fill the shape exactly, fixed-width types are stored as
  # one raw byte string in `tensor_content`.
  if is_same_size and numpy_dtype in _TENSOR_CONTENT_TYPES and shape_size > 1:
    if nparray.size * nparray.itemsize >= (1 << 31):
      raise ValueError(
          "Cannot create a tensor proto whose content is larger than 2GB.")
    tensor_proto.tensor_content = nparray.tobytes()
    return tensor_proto

  # If we were not given values as a numpy array, compute the proto_values
  # from the given values directly, to avoid numpy trimming nulls from the
  # strings. Since values could be a list of strings, or a multi-dimensional
  # list of lists that might or might not correspond to the given shape,
  # we flatten it conservatively.
  if numpy_dtype == dtypes.string and not isinstance(values, np.ndarray):
    proto_values = _FlattenToStrings(values)

    # At this point, values may be a list of objects that we could not
    # identify a common type for (hence it was inferred as
    # np.object/dtypes.string).  If we are unable to convert it to a
    # string, we raise a more helpful error message.
    #
    # Ideally, we'd be able to convert the elements of the list to a
    # common type, but this type inference requires some thinking and
    # so we defer it for now.
    try:
      str_values = [compat.as_bytes(x) for x in proto_values]
    except TypeError:
      raise TypeError("Failed to convert object of type %s to Tensor. "
                      "Contents: %s. Consider casting elements to a "
                      "supported type." % (type(values), values))
    tensor_proto.string_val.extend(str_values)
    return tensor_proto

  # TensorFlow expects C order (a.k.a., eigen row major).
  proto_values = nparray.ravel()

  # Everything else (including broadcast and partially filled inputs, where
  # is_same_size is False) goes into the typed repeated *_val fields.
  append_fn = GetNumpyAppendFn(proto_values.dtype)
  if append_fn is None:
    raise TypeError(
        "Element type not supported in TensorProto: %s" % numpy_dtype.name)
  append_fn(tensor_proto, proto_values)

  return tensor_proto
# pylint: enable=invalid-name
568
569
@tf_export("make_ndarray")
def MakeNdarray(tensor):
  """Create a numpy ndarray from a tensor.

  Create a numpy ndarray with the same shape and data as the tensor.

  For example:

  ```python
  # Tensor a has shape (2,3)
  a = tf.constant([[1,2,3],[4,5,6]])
  proto_tensor = tf.make_tensor_proto(a)  # convert `tensor a` to a proto tensor
  tf.make_ndarray(proto_tensor) # output: array([[1, 2, 3],
  #                                              [4, 5, 6]], dtype=int32)
  # output has shape (2,3)
  ```

  Args:
    tensor: A TensorProto.

  Returns:
    A numpy array with the tensor contents. If the proto stores fewer values
    than its shape requires, the last stored value is repeated to fill it.
    If it stores none, the fill is zeros (empty strings for string tensors).

  Raises:
    TypeError: if tensor has unsupported type.

  """
  shape = [d.size for d in tensor.tensor_shape.dim]
  num_elements = np.prod(shape, dtype=np.int64)
  tensor_dtype = dtypes.as_dtype(tensor.dtype)
  dtype = tensor_dtype.as_numpy_dtype

  # Raw-bytes encoding. Copy so the result owns its (writable) memory instead
  # of aliasing the proto's buffer.
  if tensor.tensor_content:
    return (np.frombuffer(tensor.tensor_content,
                          dtype=dtype).copy().reshape(shape))

  if tensor_dtype == dtypes.string:
    # np.pad throws on these arrays of type np.object.
    values = list(tensor.string_val)
    padding = num_elements - len(values)
    if padding > 0:
      last = values[-1] if values else ""
      values.extend([last] * padding)
    return np.array(values, dtype=dtype).reshape(shape)

  if tensor_dtype == dtypes.float16 or tensor_dtype == dtypes.bfloat16:
    # The half_val field stores the raw 16-bit patterns. Reinterpret them (not
    # convert) as the 16-bit float type via .view(), rather than assigning to
    # ndarray.dtype, which NumPy documents as discouraged.
    values = np.fromiter(tensor.half_val, dtype=np.uint16).view(dtype)
  elif tensor_dtype == dtypes.float32:
    values = np.fromiter(tensor.float_val, dtype=dtype)
  elif tensor_dtype == dtypes.float64:
    values = np.fromiter(tensor.double_val, dtype=dtype)
  elif tensor_dtype in [
      dtypes.int32, dtypes.uint8, dtypes.uint16, dtypes.int16, dtypes.int8,
      dtypes.qint32, dtypes.quint8, dtypes.qint8, dtypes.qint16, dtypes.quint16
  ]:
    values = np.fromiter(tensor.int_val, dtype=dtype)
  elif tensor_dtype == dtypes.int64:
    values = np.fromiter(tensor.int64_val, dtype=dtype)
  elif tensor_dtype == dtypes.uint32:
    values = np.fromiter(tensor.uint32_val, dtype=dtype)
  elif tensor_dtype == dtypes.uint64:
    values = np.fromiter(tensor.uint64_val, dtype=dtype)
  elif tensor_dtype == dtypes.complex64:
    # Stored as interleaved (real, imag) pairs; zip(it, it) pairs them up.
    it = iter(tensor.scomplex_val)
    values = np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype)
  elif tensor_dtype == dtypes.complex128:
    it = iter(tensor.dcomplex_val)
    values = np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype)
  elif tensor_dtype == dtypes.bool:
    values = np.fromiter(tensor.bool_val, dtype=dtype)
  else:
    raise TypeError("Unsupported tensor type: %s" % tensor.dtype)

  if values.size == 0:
    return np.zeros(shape, dtype)

  # Fewer stored values than elements: repeat the last one ("edge" padding).
  if values.size != num_elements:
    values = np.pad(values, (0, num_elements - values.size), "edge")

  return values.reshape(shape)
653
654
def ShapeEquals(tensor_proto, shape):
  """Returns True if "tensor_proto" has the given "shape".

  Args:
    tensor_proto: A TensorProto.
    shape: A tensor shape, expressed as a TensorShapeProto, list, or tuple.

  Returns:
    True if "tensor_proto" has the given "shape" -- the same rank and the same
    size in every dimension -- otherwise False.

  Raises:
    TypeError: If "tensor_proto" is not a TensorProto, or shape is not a
      TensorShapeProto, list, or tuple.
  """
  if not isinstance(tensor_proto, tensor_pb2.TensorProto):
    raise TypeError("tensor_proto is not a tensor_pb2.TensorProto object")
  if isinstance(shape, tensor_shape_pb2.TensorShapeProto):
    shape = [d.size for d in shape.dim]
  elif not isinstance(shape, (list, tuple)):
    raise TypeError("shape is not a list or tuple")
  tensor_shape_list = [d.size for d in tensor_proto.tensor_shape.dim]
  # zip() stops at the shorter sequence, so the ranks must be compared
  # explicitly; otherwise e.g. [2, 3] would be reported equal to [2].
  return (len(tensor_shape_list) == len(shape) and
          all(x == y for x, y in zip(tensor_shape_list, shape)))
677
678
def _ConstantValue(tensor, partial):
  """Statically evaluates a graph `tensor` from the op that produces it.

  Only the op types handled below are understood; for any other op, or when a
  required input cannot itself be evaluated, None is returned.

  Args:
    tensor: The `ops.Tensor` to evaluate.
    partial: Forwarded when recursing through Pack, Unpack, Split (value
      input) and the pass-through ops. When True, a Pack whose inputs are only
      partly known keeps None for the unknown entries instead of giving up.

  Returns:
    A numpy value for `tensor`, or None if it cannot be computed.

  Raises:
    TypeError: If `tensor` is not an `ops.Tensor`.
  """
  # TODO(touts): Support Variables?
  if not isinstance(tensor, ops.Tensor):
    raise TypeError("%r is not a Tensor, has type %s" % (tensor, type(tensor)))
  if tensor.op.type == "Const":
    return MakeNdarray(tensor.op.get_attr("value"))
  elif tensor.op.type == "Shape":
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.is_fully_defined():
      return np.array(
          [dim.value for dim in input_shape.dims],
          dtype=tensor.dtype.as_numpy_dtype)
    else:
      return None
  elif tensor.op.type == "Size":
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.is_fully_defined():
      return np.prod([dim.value for dim in input_shape.dims], dtype=np.int32)
    else:
      return None
  elif tensor.op.type == "Rank":
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.ndims is not None:
      # A 0-d int32 array holding the rank.
      return np.ndarray(
          shape=(),
          buffer=np.array([input_shape.ndims], dtype=np.int32),
          dtype=np.int32)
    else:
      return None
  elif tensor.op.type == "Range":
    start = constant_value(tensor.op.inputs[0])
    if start is None:
      return None
    limit = constant_value(tensor.op.inputs[1])
    if limit is None:
      return None
    delta = constant_value(tensor.op.inputs[2])
    if delta is None:
      return None
    return np.arange(start, limit, delta, dtype=tensor.dtype.as_numpy_dtype)
  elif tensor.op.type == "Cast":
    pre_cast = constant_value(tensor.op.inputs[0])
    if pre_cast is None:
      return None
    cast_dtype = dtypes.as_dtype(tensor.op.get_attr("DstT"))
    return pre_cast.astype(cast_dtype.as_numpy_dtype)
  elif tensor.op.type == "Concat":
    # Concat takes the axis as its first input.
    dim = constant_value(tensor.op.inputs[0])
    if dim is None:
      return None
    values = []
    for x in tensor.op.inputs[1:]:
      value = constant_value(x)
      if value is None:
        return None
      values.append(value)
    return np.concatenate(values, axis=dim)
  elif tensor.op.type == "ConcatV2":
    # ConcatV2 takes the axis as its last input.
    dim = constant_value(tensor.op.inputs[-1])
    if dim is None:
      return None
    values = []
    for x in tensor.op.inputs[:-1]:
      value = constant_value(x)
      if value is None:
        return None
      values.append(value)
    return np.concatenate(values, axis=dim)
  elif tensor.op.type == "Pack":
    values = []
    # Some imported GraphDefs have Pack ops with zero inputs. Those are invalid
    # and shouldn't be produced, but to deal sensibly with them here we check
    # and return None.
    if not tensor.op.inputs:
      return None
    # We can't handle axis != 0 Packs at the moment.
    if tensor.op.get_attr("axis") != 0:
      return None
    for x in tensor.op.inputs:
      value = constant_value(x, partial)
      if value is None and not partial:
        return None
      values.append(value)
    try:
      return np.array(values)
    except ValueError:
      # With partial=True, `values` can mix None with non-scalar arrays.
      # NumPy >= 1.24 refuses to build such ragged arrays implicitly, so
      # request an object array explicitly.
      return np.array(values, dtype=object)
  elif tensor.op.type == "Unpack":
    # We can't handle axis != 0 Unpacks at the moment.
    if tensor.op.get_attr("axis") != 0:
      return None
    value = constant_value(tensor.op.inputs[0], partial)
    if value is None:
      return None
    return value[tensor.value_index]
  elif tensor.op.type == "Split":
    dim = constant_value(tensor.op.inputs[0])
    value = constant_value(tensor.op.inputs[1], partial)
    if value is None or dim is None:
      return None
    split = np.split(value, tensor.op.get_attr("num_split"), dim)
    return split[tensor.value_index]
  elif tensor.op.type == "Fill":
    fill_shape = tensor.shape
    fill_value = constant_value(tensor.op.inputs[1])
    if fill_shape.is_fully_defined() and fill_value is not None:
      return np.full(fill_shape.as_list(), fill_value, dtype=fill_value.dtype)
    else:
      return None
  elif tensor.op.type == "Equal":
    value1 = constant_value(tensor.op.inputs[0])
    if value1 is None:
      return None
    value2 = constant_value(tensor.op.inputs[1])
    if value2 is None:
      return None
    return np.equal(value1, value2)
  elif tensor.op.type == "NotEqual":
    value1 = constant_value(tensor.op.inputs[0])
    if value1 is None:
      return None
    value2 = constant_value(tensor.op.inputs[1])
    if value2 is None:
      return None
    return np.not_equal(value1, value2)
  elif tensor.op.type == "StopGradient":
    return constant_value(tensor.op.inputs[0], partial)
  elif tensor.op.type in ("CheckNumericsV2", "DebugIdentityV2", "Identity"):
    # Value-preserving ops: evaluate their input instead.
    return constant_value(tensor.op.inputs[0], partial)
  else:
    return None
807
808
@tf_export("get_static_value")
def constant_value(tensor, partial=False):  # pylint: disable=invalid-name
  """Returns the constant value of the given tensor, if efficiently calculable.

  This function attempts to partially evaluate the given tensor, and
  returns its value as a numpy ndarray if this succeeds.

  Example usage:

  >>> a = tf.constant(10)
  >>> tf.get_static_value(a)
  10
  >>> b = tf.constant(20)
  >>> tf.get_static_value(tf.add(a, b))
  30

  >>> # `tf.Variable` is not supported.
  >>> c = tf.Variable(30)
  >>> print(tf.get_static_value(c))
  None

  Using `partial` option is most relevant when calling `get_static_value` inside
  a `tf.function`. Setting it to `True` will return the results but for the
  values that cannot be evaluated will be `None`. For example:

  ```python
  class Foo(object):
    def __init__(self):
      self.a = tf.Variable(1)
      self.b = tf.constant(2)

    @tf.function
    def bar(self, partial):
      packed = tf.raw_ops.Pack(values=[self.a, self.b])
      static_val = tf.get_static_value(packed, partial=partial)
      tf.print(static_val)

  f = Foo()
  f.bar(partial=True)  # `array([None, array(2, dtype=int32)], dtype=object)`
  f.bar(partial=False)  # `None`
  ```

  Compatibility(V1): If `constant_value(tensor)` returns a non-`None` result, it
  will no longer be possible to feed a different value for `tensor`. This allows
  the result of this function to influence the graph that is constructed, and
  permits static shape optimizations.

  Args:
    tensor: The Tensor to be evaluated. Values that are not TF-native types
      (for which `is_tensor` is False, e.g. Python scalars or numpy arrays)
      are accepted and returned unchanged.
    partial: If True, the returned numpy array is allowed to have partially
      evaluated values. Values that can't be evaluated will be None.

  Returns:
    A numpy ndarray containing the constant value of the given `tensor`;
    `tensor` itself if it is not a TF-native type; or None if the value cannot
    be calculated. None is also returned for TF-native values that are not an
    `ops.Tensor`, and for EagerTensors whose value cannot be materialized
    (their `.numpy()` raises `UnimplementedError`).
  """
  if isinstance(tensor, ops.EagerTensor):
    try:
      return tensor.numpy()
    except errors_impl.UnimplementedError:
      # Some EagerTensors may not implement .numpy/resolve, e.g. parallel
      # tensors with multiple components on different devices.
      return None
  if not is_tensor(tensor):
    # Plain Python / numpy values are already "static"; hand them back as-is.
    return tensor
  if not isinstance(tensor, ops.Tensor):
    # TF-native but not a graph Tensor: no static evaluation is attempted.
    return None
  ret = _ConstantValue(tensor, partial)
  if ret is not None:
    # The caller may now depend on the constant value of `tensor`, so we
    # conservatively prevent it from being fed.
    tensor.graph.prevent_feeding(tensor)
  return ret
885
886
def constant_value_as_shape(tensor):  # pylint: disable=invalid-name
  """A version of `constant_value()` that returns a `TensorShape`.

  This version should be used when a constant tensor value is
  interpreted as a (possibly partial) shape, e.g. in the shape
  function for `tf.reshape()`. By explicitly requesting a
  `TensorShape` as the return value, it is possible to represent
  unknown dimensions; by contrast, `constant_value()` is
  all-or-nothing.

  A scalar `tensor` must statically evaluate to -1, which denotes a shape of
  unknown rank; this holds for both graph tensors and EagerTensors. For a
  vector `tensor`, entries equal to -1 or entries that cannot be statically
  evaluated become unknown (`None`) dimensions.

  Args:
    tensor: The rank-0 or rank-1 Tensor to be evaluated.

  Returns:
    A `TensorShape` based on the constant value of the given `tensor`.

  Raises:
    ValueError: If the shape is rank-0 and is not statically known to be -1,
      or if `tensor` has a known rank greater than 1.
  """
  # Handle scalars first so eager and graph tensors agree. `constant_value`
  # evaluates EagerTensors directly, so this path also covers eager mode.
  # Iterating an eager 0-d value (as the vector path below does) would raise a
  # TypeError instead of honoring the "-1 means unknown shape" convention.
  if tensor.get_shape().ndims == 0:
    value = constant_value(tensor)
    if value is None:
      raise ValueError(
          "Received a scalar with unknown value as shape; require a statically "
          "known scalar with value '-1' to describe an unknown shape.")
    if value != -1:
      raise ValueError(
          "Received a scalar value '%s' as shape; require a statically known "
          "scalar with value '-1' to describe an unknown shape." % value)
    return tensor_shape.unknown_shape()

  if isinstance(tensor, ops.EagerTensor):
    return tensor_shape.TensorShape(
        [dim if dim != -1 else None for dim in tensor.numpy()])

  shape = tensor.get_shape().with_rank(1)
  if shape == [0]:
    return tensor_shape.TensorShape([])
  elif tensor.op.type == "Cast":
    pre_cast = constant_value_as_shape(tensor.op.inputs[0])
    if pre_cast.dims is None:
      # the input to cast has a totally undefined shape; just return that.
      return pre_cast
    cast_dtype = dtypes.as_dtype(tensor.op.get_attr("DstT"))
    if cast_dtype not in (dtypes.int32, dtypes.int64):
      return tensor_shape.unknown_shape(shape.dims[0].value)
    dest_dtype_shape_array = np.array(
        [x if x is not None else -1 for x in pre_cast.as_list()]).astype(
            cast_dtype.as_numpy_dtype)
    return tensor_shape.TensorShape([
        x if x >= 0 else None
        for x in dest_dtype_shape_array])
  elif tensor.op.type == "Shape":
    return tensor.op.inputs[0].get_shape()
  elif tensor.op.type == "Pack":
    ret = tensor_shape.TensorShape([])  # Empty list.
    # The output is rank 1, so every packed input is a scalar. Pack's axis can
    # then only be 0 or -1, and both produce the same vector, so no axis check
    # is needed. `tf.stack(scalars, axis=-1)` must not be rejected here.
    for pack_input in tensor.op.inputs:
      # `pack_input` must be a scalar. Attempt to evaluate it, and append it
      # to `ret`.
      pack_input_val = constant_value(pack_input)
      if pack_input_val is None or pack_input_val < 0:
        new_dim = tensor_shape.Dimension(None)
      else:
        new_dim = tensor_shape.Dimension(pack_input_val)
      ret = ret.concatenate([new_dim])
    return ret
  elif tensor.op.type == "Concat":
    # We assume that `tensor.op.inputs[0]` evaluates to 0, as this is
    # the only legal value when concatenating vectors, and it will
    # have been checked by a previous shape function.
    ret = tensor_shape.TensorShape([])  # Empty list.
    for concat_input in tensor.op.inputs[1:]:
      # `concat_input` must be a vector. Attempt to evaluate it as a shape,
      # and concatenate it with `ret`.
      ret = ret.concatenate(constant_value_as_shape(concat_input))
    return ret
  elif tensor.op.type == "ConcatV2":
    # We assume that `tensor.op.inputs[-1]` evaluates to 0, as this is
    # the only legal value when concatenating vectors, and it will
    # have been checked by a previous shape function.
    ret = tensor_shape.TensorShape([])  # Empty list.
    for concat_input in tensor.op.inputs[:-1]:
      # `concat_input` must be a vector. Attempt to evaluate it as a shape,
      # and concatenate it with `ret`.
      ret = ret.concatenate(constant_value_as_shape(concat_input))
    return ret
  elif tensor.op.type == "StridedSlice":
    try:
      begin = constant_value(tensor.op.inputs[1])
      end = constant_value(tensor.op.inputs[2])
      strides = constant_value(tensor.op.inputs[3])
      if begin is not None and end is not None and strides is not None:
        begin = begin[0]
        end = end[0]
        strides = strides[0]
        begin_mask = tensor.op.get_attr("begin_mask")
        if begin_mask == 1:
          begin = None
        end_mask = tensor.op.get_attr("end_mask")
        if end_mask == 1:
          end = None

        ellipsis_mask = tensor.op.get_attr("ellipsis_mask")
        new_axis_mask = tensor.op.get_attr("new_axis_mask")
        shrink_axis_mask = tensor.op.get_attr("shrink_axis_mask")
        valid_attributes = (not ellipsis_mask and not new_axis_mask and
                            not shrink_axis_mask and (not begin_mask or
                                                      (begin_mask == 1)) and
                            (not end_mask or (end_mask == 1)))
        if valid_attributes:  # additional inputs not supported
          prev = constant_value_as_shape(tensor.op.inputs[0])
          prev = prev[begin:end:strides]
          ret = tensor_shape.TensorShape(prev)
          return ret

    except ValueError:  # Could come from get_attr or slicing prev.
      pass
    except TypeError:  # Could come from slicing prev.
      pass
  elif (tensor.op.type == "Placeholder" and
        tensor.op.graph.building_function and
        hasattr(tensor.op.graph, "internal_captures")):
    # If we are inside a FuncGraph try to lookup the constant value of the
    # corresponding external capture. Note that we only look at captures and
    # not the fed inputs because those can be fed different values in different
    # instantiations of the function call or different iterations of a
    # tf.while_loop.
    for i, capture in enumerate(tensor.op.graph.internal_captures):
      if capture is tensor:
        external_capture = tensor.op.graph.external_captures[i]
        return constant_value_as_shape(external_capture)

  # Generic fallback: start from a vector of unknown entries and refine it with
  # whatever `constant_value` can compute (negative entries stay unknown).
  ret = tensor_shape.unknown_shape(shape.dims[0].value)
  value = constant_value(tensor)
  if value is not None:
    ret = ret.merge_with(
        tensor_shape.TensorShape([d if d >= 0 else None for d in value]))
  return ret
1028
1029
# TODO(mdan): Deprecate in favor of more static-friendly types.
@tf_export("is_tensor")
def is_tf_type(x):  # pylint: disable=invalid-name
  """Checks whether `x` is a TF-native type that can be passed to many TF ops.

  Use `is_tensor` to differentiate types that can be ingested by TensorFlow ops
  without any conversion (e.g., `tf.Tensor`, `tf.SparseTensor`, and
  `tf.RaggedTensor`) from types that need to be converted into tensors before
  they are ingested (e.g., numpy `ndarray` and Python scalars).

  For example, in the following code block:

  ```python
  if not tf.is_tensor(t):
    t = tf.convert_to_tensor(t)
  return t.shape, t.dtype
  ```

  we check to make sure that `t` is a tensor (and convert it if not) before
  accessing its `shape` and `dtype`.  (But note that not all TensorFlow native
  types have shapes or dtypes; `tf.data.Dataset` is an example of a TensorFlow
  native type that has neither shape nor dtype.)

  Args:
    x: A python object to check.

  Returns:
    `True` if `x` is a TensorFlow-native type, `False` otherwise.
  """
  # Objects can also opt in by exposing a truthy `is_tensor_like` attribute.
  # The result is coerced to `bool` so callers always receive a real boolean,
  # rather than whatever value that attribute happens to hold.
  return bool(
      isinstance(x, (internal.NativeObject, core.Tensor)) or
      getattr(x, "is_tensor_like", False))
1062
1063
# Deprecated alias for tensor_util.is_tf_type. It is the very same function
# object, kept so existing call sites of `is_tensor` (including ones in this
# module) keep working. Prefer `is_tf_type` in new code.
is_tensor = is_tf_type
1066
1067
def shape_tensor(shape):  # pylint: disable=invalid-name
  """Converts `shape` into a shape tensor, defaulting to int32 when empty.

  Python lists and tuples get special handling. An empty sequence has no
  elements to infer a dtype from, so int32 is requested explicitly. A
  non-empty sequence has any `Dimension` objects unwrapped to plain values
  first. Every other input is handed to `ops.convert_to_tensor` unchanged,
  with dtype inference left to it.

  Args:
    shape: A list/tuple of dimensions, or any value convertible to a tensor.

  Returns:
    The converted tensor, named "shape".
  """
  requested_dtype = None
  is_sequence = isinstance(shape, (tuple, list))
  if is_sequence and len(shape) == 0:
    requested_dtype = dtypes.int32
  elif is_sequence:
    # Mixing v1 and v2 TensorShape objects in partial conversions can yield
    # shapes such as (1, 2, Dimension(5)). Such mixed content is not
    # convertible to a Tensor, so every entry is unwrapped to its raw value.
    shape = tuple(tensor_shape.dimension_value(dim) for dim in shape)
  return ops.convert_to_tensor(shape, dtype=requested_dtype, name="shape")
1081
1082
# DO NOT USE: For testing only.
# Kill switch read by `maybe_set_static_shape`. While it is False, that
# function returns without touching the tensor's shape.
_ENABLE_MAYBE_SET_STATIC_SHAPE = True
1085
1086
def maybe_set_static_shape(tensor, shape):  # pylint: disable=invalid-name
  """Refines `tensor`'s static shape from the constant value of `shape`.

  This is a temporary workaround for shape inference across functional-op
  boundaries. Consider:

  ```python
  shape = tf.constant([3])
  @tf.function
  def f():
    u = tf.random.uniform(shape)
    return u
  ```

  C++ shape inference alone cannot determine the shape of `u` inside `f`. It
  does not see the outer graph; backtracking the captured `shape` only reaches
  a Placeholder node. This helper computes the static value of `shape` by
  following `FuncGraph` captures (see `constant_value_as_shape`) and applies
  it with `tensor.set_shape`.

  The longer-term fix is to teach C++ shape inference about captures.

  Args:
    tensor: A tensor whose static shape may be refined in place.
    shape: A shape tensor. Non-tensor values are ignored.
  """
  # The guards run in a fixed order; each one bails out early when the
  # refinement is disabled, pointless, or inapplicable.
  if not _ENABLE_MAYBE_SET_STATIC_SHAPE:
    return
  if context.executing_eagerly():
    return
  # Only graphs that are being built for a function have captured
  # placeholders for `constant_value_as_shape` to look through.
  if not ops.get_default_graph().building_function:
    return
  if tensor.shape.is_fully_defined():
    return  # Nothing left to refine.
  if not is_tensor(shape):
    return
  inferred_shape = constant_value_as_shape(shape_tensor(shape))
  tensor.set_shape(inferred_shape)
1120