• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright © 2020 Arm Ltd. All rights reserved.
2# SPDX-License-Identifier: MIT
3"""
This file contains the custom Python implementation for Arm NN Tensor objects.
5"""
6import numpy as np
7
8from .._generated.pyarmnn import Tensor as annTensor, TensorInfo, DataType_QAsymmU8, DataType_QSymmS8, \
9    DataType_QAsymmS8, DataType_Float32, DataType_QSymmS16, DataType_Signed32, DataType_Float16
10
11
class Tensor(annTensor):
    """PyArmNN Tensor that owns the numpy buffer backing its data.

    This class overrides the swig generated Tensor class in order to offer
    an easier to use public api for the Tensor object.

    When a tensor is built from a `TensorInfo`, a numpy array of the matching
    element type and element count is allocated here and its buffer is handed
    to the base class, so callers do not have to manage a separate memory
    area themselves. Copies of a tensor keep a reference to the same array.

    """

    def __init__(self, *args):
        """ Create Tensor object.

        Only the first positional argument is inspected to select the
        constructor; see the examples below for the supported forms.

        Supported tensor data types:
            `DataType_QAsymmU8`,
            `DataType_QAsymmS8`,
            `DataType_QSymmS16`,
            `DataType_QSymmS8`,
            `DataType_Signed32`,
            `DataType_Float32`,
            `DataType_Float16`

        Examples:
            Create an empty tensor
            >>> import pyarmnn as ann
            >>> ann.Tensor()

            Create tensor given tensor information
            >>> ann.Tensor(ann.TensorInfo(...))

            Create tensor from another tensor i.e. copy a tensor
            >>> ann.Tensor(ann.Tensor())

        Args:
            tensor(Tensor, optional): Create Tensor from a Tensor i.e. copy.
            tensor_info (TensorInfo, optional): Tensor information.

        Raises:
            ValueError: the data type reported by the tensor information is not
                        one of the supported data types, or no constructor
                        matches the provided arguments.

        """
        # Stays None for an empty tensor; the other constructors replace it.
        self.__memory_area = None

        # Empty constructor: nothing to allocate.
        if not args:
            super().__init__()
            return

        source = args[0]

        if isinstance(source, TensorInfo):
            # The backing memory has to be created by hand before the base
            # class can be given a buffer to work with.
            data_type = source.GetDataType()
            num_elements = source.GetNumElements()
            self.__create_memory_area(data_type, num_elements)
            super().__init__(source, self.__memory_area.data)
        elif isinstance(source, Tensor):
            # Copy constructor: the new tensor references the memory area of
            # the tensor it is copied from, it does not duplicate it.
            self.__memory_area = source.get_memory_area()
            super().__init__(source)
        else:
            raise ValueError('Incorrect number of arguments or type of arguments provided to create Tensor.')

    def __copy__(self) -> 'Tensor':
        """ Support `copy.copy` for tensors.

        Note:
            The copy is shallow with respect to the data: the returned tensor
            holds a reference to the same memory area as this tensor, the
            memory area itself is NOT duplicated.

        Example:
            Copy empty tensor
            >>> from copy import copy
            >>> import pyarmnn as ann
            >>> tensor = ann.Tensor()
            >>> copied_tensor = copy(tensor)

        Returns:
            Tensor: a new tensor object built from this one.

        """
        duplicate = Tensor(self)
        return duplicate

    def __create_memory_area(self, data_type: int, num_elements: int):
        """ Allocate the numpy array that backs this tensor.

        The array is one-dimensional, holds `num_elements` items and is left
        uninitialised.

        Args:
            data_type (int): The type of data that will be stored in the memory area.
                             See DataType_*.
            num_elements (int): Number of elements the memory area must hold.

        Raises:
            ValueError: `data_type` has no corresponding numpy type.

        """
        # Arm NN data type -> numpy element type used for the backing array.
        numpy_type_for = {
            DataType_QAsymmU8: np.uint8,
            DataType_QAsymmS8: np.int8,
            DataType_QSymmS8: np.int8,
            DataType_Float32: np.float32,
            DataType_QSymmS16: np.int16,
            DataType_Signed32: np.int32,
            DataType_Float16: np.float16,
        }

        element_type = numpy_type_for.get(data_type)
        if element_type is None:
            raise ValueError("The data type provided for this Tensor is not supported.")

        self.__memory_area = np.empty((num_elements,), dtype=element_type)

    def get_memory_area(self) -> np.ndarray:
        """ Get values that are stored by the tensor.

        Returns:
            ndarray : Tensor data (as numpy array). This is the array held by
                      the tensor itself, not a copy of it. It is None for a
                      tensor created with the empty constructor.

        """
        return self.__memory_area
127