• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include "tensor_desc.h"

#include <limits>

#include "validation.h"
#include "log.h"
19 
20 namespace OHOS {
21 namespace NeuralNetworkRuntime {
// Byte width of one element for each fixed bit width (named after the bit width; used by GetTypeSize).
constexpr uint32_t BIT8_TO_BYTE = 1;
constexpr uint32_t BIT16_TO_BYTE = 2;
constexpr uint32_t BIT32_TO_BYTE = 4;
constexpr uint32_t BIT64_TO_BYTE = 8;
// Largest number of dimensions TensorDesc::SetShape accepts.
constexpr size_t SHAPE_MAX_NUM = 10;
27 
GetTypeSize(OH_NN_DataType type)28 uint32_t GetTypeSize(OH_NN_DataType type)
29 {
30     switch (type) {
31         case OH_NN_BOOL:
32             return sizeof(bool);
33         case OH_NN_INT8:
34         case OH_NN_UINT8:
35             return BIT8_TO_BYTE;
36         case OH_NN_INT16:
37         case OH_NN_UINT16:
38         case OH_NN_FLOAT16:
39             return BIT16_TO_BYTE;
40         case OH_NN_INT32:
41         case OH_NN_UINT32:
42         case OH_NN_FLOAT32:
43             return BIT32_TO_BYTE;
44         case OH_NN_INT64:
45         case OH_NN_UINT64:
46         case OH_NN_FLOAT64:
47             return BIT64_TO_BYTE;
48         default:
49             return 0;
50     }
51 }
52 
GetDataType(OH_NN_DataType * dataType) const53 OH_NN_ReturnCode TensorDesc::GetDataType(OH_NN_DataType* dataType) const
54 {
55     if (dataType == nullptr) {
56         LOGE("GetDataType failed, dataType is nullptr.");
57         return OH_NN_INVALID_PARAMETER;
58     }
59     *dataType = m_dataType;
60     return OH_NN_SUCCESS;
61 }
62 
SetDataType(OH_NN_DataType dataType)63 OH_NN_ReturnCode TensorDesc::SetDataType(OH_NN_DataType dataType)
64 {
65     if (!Validation::ValidateTensorDataType(dataType)) {
66         LOGE("TensorDesc::SetDataType failed, dataType %{public}d is invalid.", static_cast<int>(dataType));
67         return OH_NN_INVALID_PARAMETER;
68     }
69     m_dataType = dataType;
70     return OH_NN_SUCCESS;
71 }
72 
GetFormat(OH_NN_Format * format) const73 OH_NN_ReturnCode TensorDesc::GetFormat(OH_NN_Format* format) const
74 {
75     if (format == nullptr) {
76         LOGE("GetFormat failed, format is nullptr.");
77         return OH_NN_INVALID_PARAMETER;
78     }
79     *format = m_format;
80     return OH_NN_SUCCESS;
81 }
82 
SetFormat(OH_NN_Format format)83 OH_NN_ReturnCode TensorDesc::SetFormat(OH_NN_Format format)
84 {
85     if (!Validation::ValidateTensorFormat(format)) {
86         LOGE("TensorDesc::SetFormat failed, format %{public}d is invalid.", static_cast<int>(format));
87         return OH_NN_INVALID_PARAMETER;
88     }
89     m_format = format;
90     return OH_NN_SUCCESS;
91 }
92 
GetShape(int32_t ** shape,size_t * shapeNum) const93 OH_NN_ReturnCode TensorDesc::GetShape(int32_t** shape, size_t* shapeNum) const
94 {
95     if (shape == nullptr) {
96         LOGE("GetShape failed, shape is nullptr.");
97         return OH_NN_INVALID_PARAMETER;
98     }
99     if (*shape != nullptr) {
100         LOGE("GetShape failed, *shape is not nullptr.");
101         return OH_NN_INVALID_PARAMETER;
102     }
103     if (shapeNum == nullptr) {
104         LOGE("GetShape failed, shapeNum is nullptr.");
105         return OH_NN_INVALID_PARAMETER;
106     }
107     *shape = const_cast<int32_t*>(m_shape.data());
108     *shapeNum = m_shape.size();
109     return OH_NN_SUCCESS;
110 }
111 
SetShape(const int32_t * shape,size_t shapeNum)112 OH_NN_ReturnCode TensorDesc::SetShape(const int32_t* shape, size_t shapeNum)
113 {
114     if (shape == nullptr) {
115         LOGE("SetShape failed, shape is nullptr.");
116         return OH_NN_INVALID_PARAMETER;
117     }
118     if (shapeNum == 0 || shapeNum > SHAPE_MAX_NUM) {
119         LOGE("SetShape failed, shapeNum is 0 or greater than 10.");
120         return OH_NN_INVALID_PARAMETER;
121     }
122 
123     m_shape.clear();
124     for (size_t i = 0; i < shapeNum; ++i) {
125         m_shape.emplace_back(shape[i]);
126     }
127     return OH_NN_SUCCESS;
128 }
129 
GetElementNum(size_t * elementNum) const130 OH_NN_ReturnCode TensorDesc::GetElementNum(size_t* elementNum) const
131 {
132     if (elementNum == nullptr) {
133         LOGE("GetElementNum failed, elementNum is nullptr.");
134         return OH_NN_INVALID_PARAMETER;
135     }
136     if (m_shape.empty()) {
137         LOGE("GetElementNum failed, shape is empty.");
138         return OH_NN_INVALID_PARAMETER;
139     }
140     *elementNum = 1;
141     size_t shapeNum = m_shape.size();
142     for (size_t i = 0; i < shapeNum; ++i) {
143         if (m_shape[i] <= 0) {
144             LOGW("GetElementNum return 0 with dynamic shape, shape[%{public}zu] is %{public}d.", i, m_shape[i]);
145             *elementNum = 0;
146             return OH_NN_DYNAMIC_SHAPE;
147         }
148         (*elementNum) *= m_shape[i];
149     }
150     return OH_NN_SUCCESS;
151 }
152 
GetByteSize(size_t * byteSize) const153 OH_NN_ReturnCode TensorDesc::GetByteSize(size_t* byteSize) const
154 {
155     if (byteSize == nullptr) {
156         LOGE("GetByteSize failed, byteSize is nullptr.");
157         return OH_NN_INVALID_PARAMETER;
158     }
159     *byteSize = 0;
160     size_t elementNum = 0;
161     auto ret = GetElementNum(&elementNum);
162     if (ret == OH_NN_DYNAMIC_SHAPE) {
163         return OH_NN_SUCCESS;
164     } else if (ret != OH_NN_SUCCESS) {
165         LOGE("GetByteSize failed, get element num failed.");
166         return ret;
167     }
168 
169     uint32_t typeSize = GetTypeSize(m_dataType);
170     if (typeSize == 0) {
171         LOGE("GetByteSize failed, data type is invalid.");
172         return OH_NN_INVALID_PARAMETER;
173     }
174 
175     *byteSize = elementNum * typeSize;
176 
177     return OH_NN_SUCCESS;
178 }
179 
SetName(const char * name)180 OH_NN_ReturnCode TensorDesc::SetName(const char* name)
181 {
182     if (name == nullptr) {
183         LOGE("SetName failed, name is nullptr.");
184         return OH_NN_INVALID_PARAMETER;
185     }
186     m_name = name;
187     return OH_NN_SUCCESS;
188 }
189 
190 // *name will be invalid after TensorDesc is destroyed
GetName(const char ** name) const191 OH_NN_ReturnCode TensorDesc::GetName(const char** name) const
192 {
193     if (name == nullptr) {
194         LOGE("GetName failed, name is nullptr.");
195         return OH_NN_INVALID_PARAMETER;
196     }
197     if (*name != nullptr) {
198         LOGE("GetName failed, *name is not nullptr.");
199         return OH_NN_INVALID_PARAMETER;
200     }
201     *name = m_name.c_str();
202     return OH_NN_SUCCESS;
203 }
204 }  // namespace NeuralNetworkRuntime
205 }  // namespace OHOS