1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <vector>
18 #include <memory>
19 #include <algorithm>
20 #include <map>
21 #include <string>
22 #include "tools/converter/parser/onnx/onnx_swin_attention_score_parser.h"
23 #include "tools/converter/ops/ops_def.h"
24
25 namespace mindspore {
26 namespace lite {
namespace {
// Names of the ONNX attributes understood on a SwinAttentionScore node. The parser below uses the
// same strings as the attribute keys on the converted primitive.
constexpr const char *kKeepProb = "keep_prob";
constexpr const char *kQueryTranspose = "query_transpose";
constexpr const char *kKeyTranspose = "key_transpose";
constexpr const char *kBmmScoreTransposeA = "bmm_score_transpose_a";
constexpr const char *kBmmScoreTransposeB = "bmm_score_transpose_b";
constexpr const char *kSoftmaxAxes = "softmax_axes";

// Selects how an attribute value is read from the ONNX attribute message.
// Kept as an unscoped enum so the value can be streamed directly into log messages.
enum AttrDataType {
  FLOAT = 0,     // scalar float, read via f()
  BOOL = 1,      // flag carried as an integer, read via i()
  LIST_INT = 2,  // repeated integers, read via ints()
};
}  // namespace
37
Parse(const onnx::GraphProto & onnx_graph,const onnx::NodeProto & onnx_node)38 PrimitiveCPtr OnnxSwinAttentionScoreParser::Parse(const onnx::GraphProto &onnx_graph,
39 const onnx::NodeProto &onnx_node) {
40 auto prim = std::make_unique<SwinAttentionScore>();
41 MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
42 std::map<std::string, AttrDataType> attr_map = {
43 {kKeepProb, AttrDataType::FLOAT}, {kQueryTranspose, AttrDataType::BOOL},
44 {kKeyTranspose, AttrDataType::BOOL}, {kBmmScoreTransposeA, AttrDataType::BOOL},
45 {kBmmScoreTransposeB, AttrDataType::BOOL}, {kSoftmaxAxes, AttrDataType::LIST_INT}};
46 for (const auto &onnx_node_attr : onnx_node.attribute()) {
47 const auto &attribute_name = onnx_node_attr.name();
48 auto attr_data_type = attr_map.find(attribute_name);
49 if (attr_data_type != attr_map.end()) {
50 std::vector<int64_t> softmax_axes;
51 switch (attr_data_type->second) {
52 case AttrDataType::FLOAT:
53 prim->AddAttr(attr_data_type->first, MakeValue(static_cast<float>(onnx_node_attr.f())));
54 break;
55 case AttrDataType::BOOL:
56 prim->AddAttr(attr_data_type->first, MakeValue(static_cast<bool>(onnx_node_attr.i())));
57 break;
58 case AttrDataType::LIST_INT:
59 softmax_axes.resize(onnx_node_attr.ints_size());
60 std::copy(onnx_node_attr.ints().begin(), onnx_node_attr.ints().end(), softmax_axes.begin());
61 prim->AddAttr(kSoftmaxAxes, MakeValue(softmax_axes));
62 break;
63 default:
64 MS_LOG(ERROR) << "Unexpected Attributes Data Type[" << attr_data_type->second << "] from "
65 << attr_data_type->first;
66 return nullptr;
67 }
68 }
69 }
70 return prim;
71 }
72
// Namespace-scope registrar object: constructed during static initialization, it pairs the op
// name "SwinAttentionScore" with a heap-allocated OnnxSwinAttentionScoreParser (presumably so the
// converter can look the parser up by ONNX op type). The raw `new` appears to hand ownership to
// OnnxNodeRegistrar / its registry — TODO confirm the registry frees it.
OnnxNodeRegistrar g_onnxSwinAttentionScoreParser("SwinAttentionScore", new OnnxSwinAttentionScoreParser());
74 } // namespace lite
75 } // namespace mindspore
76