1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <vector>
18 #include <memory>
19 #include <string>
20
21 #include "extendrt/mindir_loader/mindir_model/mindir_model_util.h"
22 #include "ir/tensor.h"
23 #include "ir/value.h"
24 #include "include/errorcode.h"
25 #include "nnacl/op_base.h"
26 #include "src/common/common.h"
27 #include "src/common/log_util.h"
28
29 namespace mindspore::infer::mindir {
30 static mindspore::HashMap<int, TypeId> kDefaultValueSwitchMap{
31 {mind_ir::TensorProto_DataType_BOOL, kNumberTypeBool},
32 {mind_ir::TensorProto_DataType_INT8, kNumberTypeInt8},
33 {mind_ir::TensorProto_DataType_INT16, kNumberTypeInt16},
34 {mind_ir::TensorProto_DataType_INT32, kNumberTypeInt32},
35 {mind_ir::TensorProto_DataType_INT64, kNumberTypeInt64},
36 {mind_ir::TensorProto_DataType_UINT8, kNumberTypeUInt8},
37 {mind_ir::TensorProto_DataType_UINT16, kNumberTypeUInt16},
38 {mind_ir::TensorProto_DataType_UINT32, kNumberTypeUInt32},
39 {mind_ir::TensorProto_DataType_UINT64, kNumberTypeUInt64},
40 {mind_ir::TensorProto_DataType_FLOAT16, kNumberTypeFloat16},
41 {mind_ir::TensorProto_DataType_FLOAT, kNumberTypeFloat32},
42 {mind_ir::TensorProto_DataType_FLOAT64, kNumberTypeFloat64},
43 {mind_ir::TensorProto_DataType_DOUBLE, kNumberTypeFloat64},
44 {mind_ir::TensorProto_DataType_STRING, kObjectTypeString},
45 {mind_ir::TensorProto_DataType_COMPLEX64, kNumberTypeComplex64},
46 {mind_ir::TensorProto_DataType_COMPLEX128, kNumberTypeComplex128}};
47
MakeValueFromAttribute(const mind_ir::AttributeProto & attr_proto)48 mindspore::ValuePtr MindirModelUtil::MakeValueFromAttribute(const mind_ir::AttributeProto &attr_proto) {
49 switch (attr_proto.type()) {
50 case mind_ir::AttributeProto_AttributeType_TENSORS: {
51 // embed tensor attribute
52 return MindirModelUtil::MakeValueFromTensorOrTypeAttribute(attr_proto);
53 }
54 case mind_ir::AttributeProto_AttributeType_TUPLE:
55 case mind_ir::AttributeProto_AttributeType_LIST: {
56 // list attribute
57 return MindirModelUtil::MakeValueFromListAttribute(attr_proto);
58 }
59 default: {
60 // base scalar attribute
61 return MindirModelUtil::MakeValueFromScalarAttribute(attr_proto);
62 }
63 }
64 }
65
MakeValueFromTensorOrTypeAttribute(const mind_ir::AttributeProto & attr_proto)66 mindspore::ValuePtr MindirModelUtil::MakeValueFromTensorOrTypeAttribute(const mind_ir::AttributeProto &attr_proto) {
67 auto tensor_proto = attr_proto.tensors(0);
68 if (tensor_proto.has_raw_data()) {
69 // For real tensor
70 return MindirModelUtil::MakeValueFromTensorAttribute(tensor_proto);
71 } else {
72 // for data type
73 const int attr_tensor_type = tensor_proto.data_type();
74 auto iter = kDefaultValueSwitchMap.find(attr_tensor_type);
75 MS_CHECK_TRUE_MSG(iter == kDefaultValueSwitchMap.end(), nullptr,
76 "MindirModelUtil: Generate value ptr failed, cannot find attr tensor type " << attr_tensor_type);
77 return TypeIdToType(iter->second);
78 }
79 }
80
MakeValueFromTensorAttribute(const mind_ir::TensorProto & tensor_proto,bool need_load_data)81 mindspore::ValuePtr MindirModelUtil::MakeValueFromTensorAttribute(const mind_ir::TensorProto &tensor_proto,
82 bool need_load_data) {
83 ShapeVector shape;
84 auto attr_tensor_type = tensor_proto.data_type();
85 for (int i = 0; i < tensor_proto.dims_size(); i++) {
86 shape.push_back(tensor_proto.dims(i));
87 }
88 tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
89
90 MS_EXCEPTION_IF_NULL(tensor);
91 const std::string &tensor_buf = tensor_proto.raw_data();
92 if (tensor_proto.has_raw_data()) {
93 auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor->data_c());
94 auto ret = memcpy_s(tensor_data_buf, tensor->data().nbytes(), tensor_buf.data(), tensor_buf.size());
95 MS_CHECK_TRUE_MSG(
96 ret != mindspore::lite::RET_OK, nullptr,
97 "MindirModelUtil: Generate tensor ptr from tensor proto failed, failed to get tensor from tensor proto.");
98 } else {
99 MS_CHECK_TRUE_MSG(
100 need_load_data, nullptr,
101 "MindirModelUtil: Generate tensor ptr from tensor proto failed, failed to get tensor from tensor proto.");
102 }
103 return tensor;
104 }
105
MakeValueFromListAttribute(const mind_ir::AttributeProto & attr_proto)106 mindspore::ValuePtr MindirModelUtil::MakeValueFromListAttribute(const mind_ir::AttributeProto &attr_proto) {
107 std::vector<mindspore::ValuePtr> vec;
108 for (int i = 0; i < attr_proto.values_size(); i++) {
109 mind_ir::AttributeProto elem_attr_proto = attr_proto.values(i);
110 mindspore::ValuePtr value_ptr = MindirModelUtil::MakeValueFromAttribute(elem_attr_proto);
111 vec.emplace_back(value_ptr);
112 }
113 auto type = attr_proto.type();
114 mindspore::ValuePtr value_sequence;
115 switch (type) {
116 case mind_ir::AttributeProto_AttributeType_TUPLE: {
117 return std::make_shared<mindspore::ValueTuple>(vec);
118 }
119 case mind_ir::AttributeProto_AttributeType_LIST: {
120 return std::make_shared<mindspore::ValueList>(vec);
121 }
122 default: {
123 MS_LOG(ERROR)
124 << "MindirModelUtil: Obtain value in sequence form failed, the attribute type should be tuple or list";
125 return nullptr;
126 }
127 }
128 }
129
MakeValueFromScalarAttribute(const mind_ir::AttributeProto & attr_proto)130 mindspore::ValuePtr MindirModelUtil::MakeValueFromScalarAttribute(const mind_ir::AttributeProto &attr_proto) {
131 auto attr_proto_type = static_cast<int>(attr_proto.type());
132 switch (attr_proto_type) {
133 case mind_ir::AttributeProto_AttributeType_STRING: {
134 auto value = static_cast<std::string>(attr_proto.s());
135 return MakeValue<std::string>(value);
136 }
137 case mind_ir::AttributeProto_AttributeType_INT8: {
138 auto value = static_cast<int8_t>(attr_proto.i());
139 return MakeValue<int8_t>(value);
140 }
141 case mind_ir::AttributeProto_AttributeType_INT16: {
142 auto value = static_cast<int16_t>(attr_proto.i());
143 return MakeValue<int16_t>(value);
144 }
145 case mind_ir::AttributeProto_AttributeType_INT32: {
146 auto value = static_cast<int32_t>(attr_proto.i());
147 return MakeValue<int32_t>(value);
148 }
149 case mind_ir::AttributeProto_AttributeType_INT64: {
150 auto value = static_cast<int64_t>(attr_proto.i());
151 return MakeValue<int64_t>(value);
152 }
153 case mind_ir::AttributeProto_AttributeType_UINT8: {
154 auto value = static_cast<uint8_t>(attr_proto.i());
155 return MakeValue<uint8_t>(value);
156 }
157 case mind_ir::AttributeProto_AttributeType_UINT16: {
158 auto value = static_cast<uint16_t>(attr_proto.i());
159 return MakeValue<uint16_t>(value);
160 }
161 case mind_ir::AttributeProto_AttributeType_UINT32: {
162 auto value = static_cast<uint32_t>(attr_proto.i());
163 return MakeValue<uint32_t>(value);
164 }
165 case mind_ir::AttributeProto_AttributeType_UINT64: {
166 auto value = static_cast<uint64_t>(attr_proto.i());
167 return MakeValue<uint64_t>(value);
168 }
169 case mind_ir::AttributeProto_AttributeType_FLOAT: {
170 auto value = static_cast<float>(attr_proto.f());
171 return MakeValue<float>(value);
172 }
173 case mind_ir::AttributeProto_AttributeType_DOUBLE: {
174 auto value = static_cast<double>(attr_proto.d());
175 return MakeValue<double>(value);
176 }
177 case mind_ir::AttributeProto_AttributeType_BOOL: {
178 auto value = static_cast<int32_t>(attr_proto.i());
179 return MakeValue<bool>(value);
180 }
181 default: {
182 MS_LOG(ERROR) << "MindirModelUtil: Obtain cnode attr in single scalar form failed, attr type " << attr_proto_type
183 << " is xinot supported ";
184 return nullptr;
185 }
186 }
187 }
188
ProtoTypeToTypeId(int32_t proto_type)189 mindspore::TypeId MindirModelUtil::ProtoTypeToTypeId(int32_t proto_type) {
190 auto it = kDefaultValueSwitchMap.find(proto_type);
191 if (it == kDefaultValueSwitchMap.end()) {
192 return kTypeUnknown;
193 }
194 return it->second;
195 }
196 } // namespace mindspore::infer::mindir
197