• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <vector>
18 #include <memory>
19 #include <algorithm>
20 #include <map>
21 #include <string>
22 #include "tools/converter/parser/onnx/onnx_swin_attention_score_parser.h"
23 #include "tools/converter/ops/ops_def.h"
24 
25 namespace mindspore {
26 namespace lite {
namespace {
// Names of the ONNX node attributes that OnnxSwinAttentionScoreParser copies onto the primitive.
constexpr char kKeepProb[] = "keep_prob";
constexpr char kQueryTranspose[] = "query_transpose";
constexpr char kKeyTranspose[] = "key_transpose";
constexpr char kBmmScoreTransposeA[] = "bmm_score_transpose_a";
constexpr char kBmmScoreTransposeB[] = "bmm_score_transpose_b";
constexpr char kSoftmaxAxes[] = "softmax_axes";

// Selects which field of the ONNX attribute holds a recognized attribute's value.
enum AttrDataType {
  FLOAT,     // scalar float, read from the attribute's float field
  BOOL,      // flag, read from the attribute's integer field (non-zero means true)
  LIST_INT,  // list of integers, read from the attribute's repeated integer field
};
}  // namespace
37 
Parse(const onnx::GraphProto & onnx_graph,const onnx::NodeProto & onnx_node)38 PrimitiveCPtr OnnxSwinAttentionScoreParser::Parse(const onnx::GraphProto &onnx_graph,
39                                                   const onnx::NodeProto &onnx_node) {
40   auto prim = std::make_unique<SwinAttentionScore>();
41   MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
42   std::map<std::string, AttrDataType> attr_map = {
43     {kKeepProb, AttrDataType::FLOAT},          {kQueryTranspose, AttrDataType::BOOL},
44     {kKeyTranspose, AttrDataType::BOOL},       {kBmmScoreTransposeA, AttrDataType::BOOL},
45     {kBmmScoreTransposeB, AttrDataType::BOOL}, {kSoftmaxAxes, AttrDataType::LIST_INT}};
46   for (const auto &onnx_node_attr : onnx_node.attribute()) {
47     const auto &attribute_name = onnx_node_attr.name();
48     auto attr_data_type = attr_map.find(attribute_name);
49     if (attr_data_type != attr_map.end()) {
50       std::vector<int64_t> softmax_axes;
51       switch (attr_data_type->second) {
52         case AttrDataType::FLOAT:
53           prim->AddAttr(attr_data_type->first, MakeValue(static_cast<float>(onnx_node_attr.f())));
54           break;
55         case AttrDataType::BOOL:
56           prim->AddAttr(attr_data_type->first, MakeValue(static_cast<bool>(onnx_node_attr.i())));
57           break;
58         case AttrDataType::LIST_INT:
59           softmax_axes.resize(onnx_node_attr.ints_size());
60           std::copy(onnx_node_attr.ints().begin(), onnx_node_attr.ints().end(), softmax_axes.begin());
61           prim->AddAttr(kSoftmaxAxes, MakeValue(softmax_axes));
62           break;
63         default:
64           MS_LOG(ERROR) << "Unexpected Attributes Data Type[" << attr_data_type->second << "] from "
65                         << attr_data_type->first;
66           return nullptr;
67       }
68     }
69   }
70   return prim;
71 }
72 
73 OnnxNodeRegistrar g_onnxSwinAttentionScoreParser("SwinAttentionScore", new OnnxSwinAttentionScoreParser());
74 }  // namespace lite
75 }  // namespace mindspore
76