1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the 'License'); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an 'AS IS' BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ====================================== 15"""XLA LiteralProto utilities.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as _np # Avoids becoming a part of public Tensorflow API. 22 23from tensorflow.compiler.xla import xla_data_pb2 24from tensorflow.compiler.xla.python_api import types 25from tensorflow.compiler.xla.python_api import xla_shape 26 27 28def ConvertLiteralToNumpyArray(literal): 29 """Converts a XLA literal to a Numpy array.""" 30 element_type = literal.shape.element_type 31 if element_type == xla_data_pb2.TUPLE: 32 return tuple( 33 ConvertLiteralToNumpyArray(subliteral) 34 for subliteral in literal.tuple_literals) 35 36 type_record = types.MAP_XLA_TYPE_TO_RECORD[element_type] 37 if not literal.shape.dimensions: 38 return _np.array( 39 getattr(literal, type_record.literal_field_name)[0], 40 type_record.numpy_dtype) 41 else: 42 # Infer the proper Numpy order from the LiteralProto's layout. The repeated 43 # field representing the array's content in the Literal is linearized. 44 # Reading is done in two steps: 45 # 46 # 1. Read the array as 1D from the LiteralProto repeated field. 47 # 2. Reshape the array to its proper shape, using the right order depending 48 # on the LiteralProto's layout. 49 layout_order = literal.shape.layout.minor_to_major 50 numpy_shape = tuple(literal.shape.dimensions) 51 if layout_order == range(len(literal.shape.dimensions)): 52 numpy_reshaper = lambda arr: arr.reshape(numpy_shape, order='F') 53 elif layout_order == range(len(literal.shape.dimensions) - 1, -1, -1): 54 numpy_reshaper = lambda arr: arr.reshape(numpy_shape, order='C') 55 else: 56 raise NotImplementedError('Unsupported layout: {0}'.format(layout_order)) 57 ndarray = _np.array( 58 getattr(literal, type_record.literal_field_name), 59 copy=False, 60 dtype=type_record.numpy_dtype) 61 return numpy_reshaper(ndarray) 62 63 64def _ConvertNumpyArrayToLiteral(ndarray): 65 """Converts a Numpy array to a XLA literal.""" 66 type_record = types.MAP_DTYPE_TO_RECORD[str(ndarray.dtype)] 67 literal = xla_data_pb2.LiteralProto() 68 literal.shape.CopyFrom(xla_shape.CreateShapeFromNumpy(ndarray).message) 69 70 if ndarray.ndim == 0: 71 getattr(literal, type_record.literal_field_name).append( 72 ndarray.astype(type_record.literal_field_type).item()) 73 else: 74 # Ndarrays with boolean dtypes need special type conversion with protobufs 75 if ndarray.dtype in {_np.bool_, _np.dtype('bool')}: 76 for element in _np.nditer(ndarray): 77 getattr(literal, type_record.literal_field_name).append( 78 type_record.literal_field_type(element)) 79 else: 80 ndarray_flat = ndarray.ravel(order='A') 81 getattr(literal, type_record.literal_field_name).extend(ndarray_flat) 82 return literal 83 84 85def ConvertNumpyArrayToLiteral(value): 86 """Converts a Numpy array or a nested tuple thereof to an XLA literal.""" 87 if isinstance(value, tuple): 88 literal = xla_data_pb2.LiteralProto() 89 literal.shape.CopyFrom(xla_shape.CreateShapeFromNumpy(value).message) 90 for component in value: 91 component_literal = literal.tuple_literals.add() 92 component_literal.CopyFrom(ConvertNumpyArrayToLiteral(component)) 93 return literal 94 else: 95 return _ConvertNumpyArrayToLiteral(value) 96