• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the 'License');
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an 'AS IS' BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ======================================
15"""XLA LiteralProto utilities."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as _np  # Avoids becoming a part of public Tensorflow API.
22
23from tensorflow.compiler.xla import xla_data_pb2
24from tensorflow.compiler.xla.python_api import types
25from tensorflow.compiler.xla.python_api import xla_shape
26
27
def ConvertLiteralToNumpyArray(literal):
  """Converts a XLA literal to a Numpy array.

  Args:
    literal: An `xla_data_pb2.LiteralProto`.

  Returns:
    If the literal's element type is TUPLE, a tuple holding the recursive
    conversion of each entry of `literal.tuple_literals`. If the literal has
    no dimensions, a 0-d Numpy array built from the first value of the
    literal's data field. Otherwise, a Numpy array with shape
    `literal.shape.dimensions`.

  Raises:
    NotImplementedError: if the literal's `minor_to_major` layout is neither
      fully column-major nor fully row-major.
  """
  element_type = literal.shape.element_type
  if element_type == xla_data_pb2.TUPLE:
    return tuple(
        ConvertLiteralToNumpyArray(subliteral)
        for subliteral in literal.tuple_literals)

  type_record = types.MAP_XLA_TYPE_TO_RECORD[element_type]
  values = getattr(literal, type_record.literal_field_name)

  if not literal.shape.dimensions:
    # Scalar literal: its single value is the first entry of the data field.
    return _np.array(values[0], type_record.numpy_dtype)

  # Infer the proper Numpy order from the LiteralProto's layout. The repeated
  # field representing the array's content in the Literal is linearized.
  # Reading is done in two steps:
  #
  # 1. Read the array as 1D from the LiteralProto repeated field.
  # 2. Reshape the array to its proper shape, using the right order depending
  #    on the LiteralProto's layout.
  numpy_shape = tuple(literal.shape.dimensions)
  rank = len(numpy_shape)
  # Convert the repeated field to a plain list before comparing. Under
  # Python 3 a `range` object never compares equal to a list, so comparing
  # against `range(...)` sent every rank >= 1 literal to NotImplementedError.
  layout_order = list(literal.shape.layout.minor_to_major)
  if layout_order == list(range(rank)):
    # Dimension 0 is the most minor: column-major (Fortran) order.
    order = 'F'
  elif layout_order == list(range(rank - 1, -1, -1)):
    # The last dimension is the most minor: row-major (C) order.
    order = 'C'
  else:
    raise NotImplementedError('Unsupported layout: {0}'.format(layout_order))

  # `asarray` rather than `array(..., copy=False)`: a protobuf repeated field
  # always needs a copy, and NumPy >= 2.0 raises ValueError when copy=False
  # cannot be honoured.
  ndarray = _np.asarray(values, dtype=type_record.numpy_dtype)
  return ndarray.reshape(numpy_shape, order=order)
62
63
def _ConvertNumpyArrayToLiteral(ndarray):
  """Converts a Numpy array to a XLA literal.

  Args:
    ndarray: A Numpy array whose dtype name is a key of
      `types.MAP_DTYPE_TO_RECORD`.

  Returns:
    An `xla_data_pb2.LiteralProto`. Its shape is copied from
    `xla_shape.CreateShapeFromNumpy(ndarray)` and its data field (named by
    the dtype's type record) holds the array's elements.
  """
  type_record = types.MAP_DTYPE_TO_RECORD[str(ndarray.dtype)]
  literal = xla_data_pb2.LiteralProto()
  literal.shape.CopyFrom(xla_shape.CreateShapeFromNumpy(ndarray).message)
  field = getattr(literal, type_record.literal_field_name)

  if ndarray.ndim == 0:
    field.append(ndarray.astype(type_record.literal_field_type).item())
    return literal

  # Flatten every dtype the same way. 'A' reads in Fortran order if the array
  # is Fortran-contiguous and in C order otherwise; this is presumably the
  # ordering CreateShapeFromNumpy encodes in the layout (TODO confirm).
  # Boolean arrays used to be walked with np.nditer, whose default 'K'
  # (memory) order differs from 'A' for non-contiguous arrays such as
  # negative-stride or strided transposed views. Those arrays were
  # serialized in a different element order than other dtypes.
  ndarray_flat = ndarray.ravel(order='A')
  if ndarray.dtype == _np.bool_:
    # numpy.bool_ is not a Python bool, so boolean elements need an explicit
    # conversion before protobufs accept them.
    field.extend(
        type_record.literal_field_type(element) for element in ndarray_flat)
  else:
    field.extend(ndarray_flat)
  return literal
83
84
def ConvertNumpyArrayToLiteral(value):
  """Converts a Numpy array, or an arbitrarily nested tuple of them, to a literal.

  Args:
    value: A Numpy array, or a tuple whose components are Numpy arrays or
      further tuples of the same kind.

  Returns:
    An `xla_data_pb2.LiteralProto`. A non-tuple value is handled by
    `_ConvertNumpyArrayToLiteral`. For a tuple, the literal's shape is built
    from the whole tuple, and each component is converted recursively into
    one entry of `tuple_literals`, preserving component order.
  """
  if not isinstance(value, tuple):
    return _ConvertNumpyArrayToLiteral(value)

  tuple_literal = xla_data_pb2.LiteralProto()
  tuple_literal.shape.CopyFrom(xla_shape.CreateShapeFromNumpy(value).message)
  for element in value:
    tuple_literal.tuple_literals.add().CopyFrom(
        ConvertNumpyArrayToLiteral(element))
  return tuple_literal
96