1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "tools/converter/adapter/acl/mapper/random_normal_mapper.h"
#include <algorithm>
#include <iterator>
#include <memory>
#include <vector>
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
#include "src/common/log_util.h"
#include "ops/op_utils.h"
#include "ops/standard_normal.h"
28
29 namespace mindspore {
30 namespace lite {
namespace {
// CNode input slot that holds the optional "shape-like" tensor whose shape drives the output shape.
constexpr size_t kRandomNormalShapeLikeInputIndex = 1;

// Largest expected CNode input count for RandomNormal: the primitive slot plus the optional
// shape-like tensor. NOTE(review): confirm no frontend emits additional inputs.
constexpr size_t kRandomNormalMaxInputSize = 2;
}  // namespace
Mapper(const CNodePtr & cnode)35 STATUS RandomNormalMapper::Mapper(const CNodePtr &cnode) {
36 ValueNodePtr value_node = nullptr;
37 PrimitivePtr src_prim = nullptr;
38 if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
39 MS_LOG(ERROR) << "Get value node and primitive from cnode failed.";
40 return lite::RET_ERROR;
41 }
42 auto func_graph = cnode->func_graph();
43 if (func_graph == nullptr) {
44 MS_LOG(ERROR) << "FuncGraph of node " << cnode->fullname_with_scope() << " is nullptr";
45 return lite::RET_ERROR;
46 }
47 auto manager = func_graph->manager();
48 if (manager == nullptr) {
49 MS_LOG(ERROR) << "FuncGraphManager of node " << cnode->fullname_with_scope() << " is nullptr";
50 return lite::RET_ERROR;
51 }
52 ops::RandomNormal random_normal(src_prim);
53 float mean = 0;
54 float scale = 1;
55 ops::StandardNormal standard_normal_op;
56 if (random_normal.HasAttr(ops::kSeed)) {
57 auto seed = random_normal.get_seed();
58 standard_normal_op.set_seed(seed);
59 standard_normal_op.set_seed2(seed);
60 }
61 if (random_normal.HasAttr(ops::kMean)) {
62 mean = random_normal.get_mean();
63 }
64 if (random_normal.HasAttr(ops::kScale)) {
65 scale = random_normal.get_scale();
66 }
67 TypeId type_id = kNumberTypeFloat32;
68 if (src_prim->HasAttr(ops::kDataType)) {
69 type_id = static_cast<TypeId>(GetValue<int64_t>(src_prim->GetAttr(ops::kDataType)));
70 }
71 if (cnode->size() > kRandomNormalShapeLikeInputIndex) {
72 auto shape_like_node = cnode->input(kRandomNormalShapeLikeInputIndex);
73 auto shape_node = NewCNode(cnode, prim::kPrimShape, {shape_like_node}, {}, kNumberTypeInt32,
74 cnode->fullname_with_scope() + "_shape");
75 if (!shape_node) {
76 MS_LOG(ERROR) << "Failed to create shape input for node " << cnode->fullname_with_scope();
77 return RET_ERROR;
78 }
79 manager->SetEdge(cnode, kRandomNormalShapeLikeInputIndex, shape_node);
80 } else if (src_prim->HasAttr(ops::kShape)) {
81 auto shape = GetValue<ShapeVector>(src_prim->GetAttr(ops::kShape));
82 std::vector<int32_t> shape_int32;
83 std::transform(shape.begin(), shape.end(), std::back_inserter(shape_int32),
84 [](auto &dim) { return static_cast<int32_t>(dim); });
85 auto shape_node = opt::BuildIntVecParameterNode(func_graph, shape_int32, cnode->fullname_with_scope() + "_shape");
86 if (!shape_node) {
87 MS_LOG(ERROR) << "Failed to create shape input for node " << cnode->fullname_with_scope();
88 return RET_ERROR;
89 }
90 manager->AddEdge(cnode, shape_node);
91 } else {
92 MS_LOG(ERROR) << "RandomNormal node does not has attribute shape or shape input, node "
93 << cnode->fullname_with_scope();
94 return RET_ERROR;
95 }
96 auto dst_prim = standard_normal_op.GetPrim();
97 dst_prim->set_attr(ops::kOutputDType, TypeIdToType(type_id));
98 value_node->set_value(dst_prim);
99 CNodePtr cur_node = cnode;
100 if (scale != 1) {
101 auto scale_param = opt::BuildFloatValueParameterNode(func_graph, scale, cnode->fullname_with_scope() + "_scale");
102 if (scale_param == nullptr) {
103 MS_LOG(ERROR) << "Failed to create scale parameter for node " << cnode->fullname_with_scope();
104 return RET_ERROR;
105 }
106 auto mul_node = NewCNode(cnode, prim::kPrimMul, {cnode, scale_param}, cnode->abstract()->Clone(),
107 cnode->fullname_with_scope() + "_scale");
108 if (mul_node == nullptr) {
109 MS_LOG(ERROR) << "Failed to create scale node for node " << cnode->fullname_with_scope();
110 return RET_ERROR;
111 }
112 cur_node = mul_node;
113 }
114 if (mean != 0) {
115 auto mean_param = opt::BuildFloatValueParameterNode(func_graph, mean, cnode->fullname_with_scope() + "_mean");
116 if (mean_param == nullptr) {
117 MS_LOG(ERROR) << "Failed to create mean parameter of node " << cnode->fullname_with_scope();
118 return RET_ERROR;
119 }
120 auto add_node = NewCNode(cnode, prim::kPrimAdd, {cnode, mean_param}, cnode->abstract()->Clone(),
121 cnode->fullname_with_scope() + "_mean");
122 if (add_node == nullptr) {
123 MS_LOG(ERROR) << "Failed to create mean node for node " << cnode->fullname_with_scope();
124 return RET_ERROR;
125 }
126 cur_node = add_node;
127 }
128 if (cur_node != cnode) {
129 manager->Replace(cnode, cur_node);
130 }
131 return RET_OK;
132 }
133
// Registers RandomNormalMapper under kNameRandomNormal; presumably the ACL adapter looks mappers up by
// primitive name through this registry (see primitive_mapper_register.h).
REGISTER_PRIMITIVE_MAPPER(kNameRandomNormal, RandomNormalMapper)
135 } // namespace lite
136 } // namespace mindspore
137