• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "tools/converter/adapter/acl/mapper/random_normal_mapper.h"
18 #include <memory>
19 #include <vector>
20 #include <algorithm>
21 #include "mindspore/core/ops/math_ops.h"
22 #include "mindspore/core/ops/framework_ops.h"
23 #include "tools/optimizer/common/gllo_utils.h"
24 #include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
25 #include "src/common/log_util.h"
26 #include "ops/op_utils.h"
27 #include "ops/standard_normal.h"
28 
29 namespace mindspore {
30 namespace lite {
31 namespace {
32 constexpr size_t kRandomNormalMaxInputSize = 2;
33 constexpr size_t kRandomNormalShapeLikeInputIndex = 1;
34 }  // namespace
Mapper(const CNodePtr & cnode)35 STATUS RandomNormalMapper::Mapper(const CNodePtr &cnode) {
36   ValueNodePtr value_node = nullptr;
37   PrimitivePtr src_prim = nullptr;
38   if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
39     MS_LOG(ERROR) << "Get value node and primitive from cnode failed.";
40     return lite::RET_ERROR;
41   }
42   auto func_graph = cnode->func_graph();
43   if (func_graph == nullptr) {
44     MS_LOG(ERROR) << "FuncGraph of node " << cnode->fullname_with_scope() << " is nullptr";
45     return lite::RET_ERROR;
46   }
47   auto manager = func_graph->manager();
48   if (manager == nullptr) {
49     MS_LOG(ERROR) << "FuncGraphManager of node " << cnode->fullname_with_scope() << " is nullptr";
50     return lite::RET_ERROR;
51   }
52   ops::RandomNormal random_normal(src_prim);
53   float mean = 0;
54   float scale = 1;
55   ops::StandardNormal standard_normal_op;
56   if (random_normal.HasAttr(ops::kSeed)) {
57     auto seed = random_normal.get_seed();
58     standard_normal_op.set_seed(seed);
59     standard_normal_op.set_seed2(seed);
60   }
61   if (random_normal.HasAttr(ops::kMean)) {
62     mean = random_normal.get_mean();
63   }
64   if (random_normal.HasAttr(ops::kScale)) {
65     scale = random_normal.get_scale();
66   }
67   TypeId type_id = kNumberTypeFloat32;
68   if (src_prim->HasAttr(ops::kDataType)) {
69     type_id = static_cast<TypeId>(GetValue<int64_t>(src_prim->GetAttr(ops::kDataType)));
70   }
71   if (cnode->size() > kRandomNormalShapeLikeInputIndex) {
72     auto shape_like_node = cnode->input(kRandomNormalShapeLikeInputIndex);
73     auto shape_node = NewCNode(cnode, prim::kPrimShape, {shape_like_node}, {}, kNumberTypeInt32,
74                                cnode->fullname_with_scope() + "_shape");
75     if (!shape_node) {
76       MS_LOG(ERROR) << "Failed to create shape input for node " << cnode->fullname_with_scope();
77       return RET_ERROR;
78     }
79     manager->SetEdge(cnode, kRandomNormalShapeLikeInputIndex, shape_node);
80   } else if (src_prim->HasAttr(ops::kShape)) {
81     auto shape = GetValue<ShapeVector>(src_prim->GetAttr(ops::kShape));
82     std::vector<int32_t> shape_int32;
83     std::transform(shape.begin(), shape.end(), std::back_inserter(shape_int32),
84                    [](auto &dim) { return static_cast<int32_t>(dim); });
85     auto shape_node = opt::BuildIntVecParameterNode(func_graph, shape_int32, cnode->fullname_with_scope() + "_shape");
86     if (!shape_node) {
87       MS_LOG(ERROR) << "Failed to create shape input for node " << cnode->fullname_with_scope();
88       return RET_ERROR;
89     }
90     manager->AddEdge(cnode, shape_node);
91   } else {
92     MS_LOG(ERROR) << "RandomNormal node does not has attribute shape or shape input, node "
93                   << cnode->fullname_with_scope();
94     return RET_ERROR;
95   }
96   auto dst_prim = standard_normal_op.GetPrim();
97   dst_prim->set_attr(ops::kOutputDType, TypeIdToType(type_id));
98   value_node->set_value(dst_prim);
99   CNodePtr cur_node = cnode;
100   if (scale != 1) {
101     auto scale_param = opt::BuildFloatValueParameterNode(func_graph, scale, cnode->fullname_with_scope() + "_scale");
102     if (scale_param == nullptr) {
103       MS_LOG(ERROR) << "Failed to create scale parameter for node " << cnode->fullname_with_scope();
104       return RET_ERROR;
105     }
106     auto mul_node = NewCNode(cnode, prim::kPrimMul, {cnode, scale_param}, cnode->abstract()->Clone(),
107                              cnode->fullname_with_scope() + "_scale");
108     if (mul_node == nullptr) {
109       MS_LOG(ERROR) << "Failed to create scale node for node " << cnode->fullname_with_scope();
110       return RET_ERROR;
111     }
112     cur_node = mul_node;
113   }
114   if (mean != 0) {
115     auto mean_param = opt::BuildFloatValueParameterNode(func_graph, mean, cnode->fullname_with_scope() + "_mean");
116     if (mean_param == nullptr) {
117       MS_LOG(ERROR) << "Failed to create mean parameter of node " << cnode->fullname_with_scope();
118       return RET_ERROR;
119     }
120     auto add_node = NewCNode(cnode, prim::kPrimAdd, {cnode, mean_param}, cnode->abstract()->Clone(),
121                              cnode->fullname_with_scope() + "_mean");
122     if (add_node == nullptr) {
123       MS_LOG(ERROR) << "Failed to create mean node for node " << cnode->fullname_with_scope();
124       return RET_ERROR;
125     }
126     cur_node = add_node;
127   }
128   if (cur_node != cnode) {
129     manager->Replace(cnode, cur_node);
130   }
131   return RET_OK;
132 }
133 
134 REGISTER_PRIMITIVE_MAPPER(kNameRandomNormal, RandomNormalMapper)
135 }  // namespace lite
136 }  // namespace mindspore
137