1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "tensor_desc.h"

#include <limits>

#include "validation.h"
#include "log.h"
19
20 namespace OHOS {
21 namespace NeuralNetworkRuntime {
// Byte widths of the fixed-size element types, used by GetTypeSize().
constexpr uint32_t BIT8_TO_BYTE = 1;
constexpr uint32_t BIT16_TO_BYTE = 2;
constexpr uint32_t BIT32_TO_BYTE = 4;
constexpr uint32_t BIT64_TO_BYTE = 8;
// Upper bound on the number of dimensions accepted by TensorDesc::SetShape().
constexpr size_t SHAPE_MAX_NUM = 10;
27
GetTypeSize(OH_NN_DataType type)28 uint32_t GetTypeSize(OH_NN_DataType type)
29 {
30 switch (type) {
31 case OH_NN_BOOL:
32 return sizeof(bool);
33 case OH_NN_INT8:
34 case OH_NN_UINT8:
35 return BIT8_TO_BYTE;
36 case OH_NN_INT16:
37 case OH_NN_UINT16:
38 case OH_NN_FLOAT16:
39 return BIT16_TO_BYTE;
40 case OH_NN_INT32:
41 case OH_NN_UINT32:
42 case OH_NN_FLOAT32:
43 return BIT32_TO_BYTE;
44 case OH_NN_INT64:
45 case OH_NN_UINT64:
46 case OH_NN_FLOAT64:
47 return BIT64_TO_BYTE;
48 default:
49 return 0;
50 }
51 }
52
GetDataType(OH_NN_DataType * dataType) const53 OH_NN_ReturnCode TensorDesc::GetDataType(OH_NN_DataType* dataType) const
54 {
55 if (dataType == nullptr) {
56 LOGE("GetDataType failed, dataType is nullptr.");
57 return OH_NN_INVALID_PARAMETER;
58 }
59 *dataType = m_dataType;
60 return OH_NN_SUCCESS;
61 }
62
SetDataType(OH_NN_DataType dataType)63 OH_NN_ReturnCode TensorDesc::SetDataType(OH_NN_DataType dataType)
64 {
65 if (!Validation::ValidateTensorDataType(dataType)) {
66 LOGE("TensorDesc::SetDataType failed, dataType %{public}d is invalid.", static_cast<int>(dataType));
67 return OH_NN_INVALID_PARAMETER;
68 }
69 m_dataType = dataType;
70 return OH_NN_SUCCESS;
71 }
72
GetFormat(OH_NN_Format * format) const73 OH_NN_ReturnCode TensorDesc::GetFormat(OH_NN_Format* format) const
74 {
75 if (format == nullptr) {
76 LOGE("GetFormat failed, format is nullptr.");
77 return OH_NN_INVALID_PARAMETER;
78 }
79 *format = m_format;
80 return OH_NN_SUCCESS;
81 }
82
SetFormat(OH_NN_Format format)83 OH_NN_ReturnCode TensorDesc::SetFormat(OH_NN_Format format)
84 {
85 if (!Validation::ValidateTensorFormat(format)) {
86 LOGE("TensorDesc::SetFormat failed, format %{public}d is invalid.", static_cast<int>(format));
87 return OH_NN_INVALID_PARAMETER;
88 }
89 m_format = format;
90 return OH_NN_SUCCESS;
91 }
92
GetShape(int32_t ** shape,size_t * shapeNum) const93 OH_NN_ReturnCode TensorDesc::GetShape(int32_t** shape, size_t* shapeNum) const
94 {
95 if (shape == nullptr) {
96 LOGE("GetShape failed, shape is nullptr.");
97 return OH_NN_INVALID_PARAMETER;
98 }
99 if (*shape != nullptr) {
100 LOGE("GetShape failed, *shape is not nullptr.");
101 return OH_NN_INVALID_PARAMETER;
102 }
103 if (shapeNum == nullptr) {
104 LOGE("GetShape failed, shapeNum is nullptr.");
105 return OH_NN_INVALID_PARAMETER;
106 }
107 *shape = const_cast<int32_t*>(m_shape.data());
108 *shapeNum = m_shape.size();
109 return OH_NN_SUCCESS;
110 }
111
SetShape(const int32_t * shape,size_t shapeNum)112 OH_NN_ReturnCode TensorDesc::SetShape(const int32_t* shape, size_t shapeNum)
113 {
114 if (shape == nullptr) {
115 LOGE("SetShape failed, shape is nullptr.");
116 return OH_NN_INVALID_PARAMETER;
117 }
118 if (shapeNum == 0 || shapeNum > SHAPE_MAX_NUM) {
119 LOGE("SetShape failed, shapeNum is 0 or greater than 10.");
120 return OH_NN_INVALID_PARAMETER;
121 }
122
123 m_shape.clear();
124 for (size_t i = 0; i < shapeNum; ++i) {
125 m_shape.emplace_back(shape[i]);
126 }
127 return OH_NN_SUCCESS;
128 }
129
GetElementNum(size_t * elementNum) const130 OH_NN_ReturnCode TensorDesc::GetElementNum(size_t* elementNum) const
131 {
132 if (elementNum == nullptr) {
133 LOGE("GetElementNum failed, elementNum is nullptr.");
134 return OH_NN_INVALID_PARAMETER;
135 }
136 if (m_shape.empty()) {
137 LOGE("GetElementNum failed, shape is empty.");
138 return OH_NN_INVALID_PARAMETER;
139 }
140 *elementNum = 1;
141 size_t shapeNum = m_shape.size();
142 for (size_t i = 0; i < shapeNum; ++i) {
143 if (m_shape[i] <= 0) {
144 LOGW("GetElementNum return 0 with dynamic shape, shape[%{public}zu] is %{public}d.", i, m_shape[i]);
145 *elementNum = 0;
146 return OH_NN_DYNAMIC_SHAPE;
147 }
148 (*elementNum) *= m_shape[i];
149 }
150 return OH_NN_SUCCESS;
151 }
152
GetByteSize(size_t * byteSize) const153 OH_NN_ReturnCode TensorDesc::GetByteSize(size_t* byteSize) const
154 {
155 if (byteSize == nullptr) {
156 LOGE("GetByteSize failed, byteSize is nullptr.");
157 return OH_NN_INVALID_PARAMETER;
158 }
159 *byteSize = 0;
160 size_t elementNum = 0;
161 auto ret = GetElementNum(&elementNum);
162 if (ret == OH_NN_DYNAMIC_SHAPE) {
163 return OH_NN_SUCCESS;
164 } else if (ret != OH_NN_SUCCESS) {
165 LOGE("GetByteSize failed, get element num failed.");
166 return ret;
167 }
168
169 uint32_t typeSize = GetTypeSize(m_dataType);
170 if (typeSize == 0) {
171 LOGE("GetByteSize failed, data type is invalid.");
172 return OH_NN_INVALID_PARAMETER;
173 }
174
175 *byteSize = elementNum * typeSize;
176
177 return OH_NN_SUCCESS;
178 }
179
SetName(const char * name)180 OH_NN_ReturnCode TensorDesc::SetName(const char* name)
181 {
182 if (name == nullptr) {
183 LOGE("SetName failed, name is nullptr.");
184 return OH_NN_INVALID_PARAMETER;
185 }
186 m_name = name;
187 return OH_NN_SUCCESS;
188 }
189
190 // *name will be invalid after TensorDesc is destroyed
GetName(const char ** name) const191 OH_NN_ReturnCode TensorDesc::GetName(const char** name) const
192 {
193 if (name == nullptr) {
194 LOGE("GetName failed, name is nullptr.");
195 return OH_NN_INVALID_PARAMETER;
196 }
197 if (*name != nullptr) {
198 LOGE("GetName failed, *name is not nullptr.");
199 return OH_NN_INVALID_PARAMETER;
200 }
201 *name = m_name.c_str();
202 return OH_NN_SUCCESS;
203 }
204 } // namespace NeuralNetworkRuntime
205 } // namespace OHOS