1 /**
2 * Copyright 2022-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/graph_kernel/core/value_depend_op_utils.h"
17
18 #include <memory>
19 #include <vector>
20
21 #include "base/base.h"
22 #include "mindspore/core/ops/array_ops.h"
23 #include "mindspore/core/ops/lite_ops.h"
24 #include "mindspore/core/ops/math_ops.h"
25 #include "mindspore/core/ops/nn_ops.h"
26 #include "mindspore/core/ops/op_def.h"
27 #include "mindspore/core/ops/sequence_ops.h"
28 #include "ops/primitive_c.h"
29 #include "utils/anf_utils.h"
30 #include "utils/check_convert_utils.h"
31
32 namespace mindspore::graphkernel {
GetOpIndexInfo()33 const std::unordered_map<std::string, HashSet<size_t>> &ValueDependOpUtils::GetOpIndexInfo() {
34 static const std::unordered_map<std::string, HashSet<size_t>> op_idx_info_ = {
35 {prim::kPrimReshape->name(), {1}},
36 {prim::kPrimReduceMax->name(), {1}},
37 {prim::kPrimExpandDims->name(), {1}},
38 {prim::kPrimReduceMin->name(), {1}},
39 {prim::kPrimReduceSum->name(), {1}},
40 {prim::kPrimTranspose->name(), {1}},
41 {prim::kPrimTile->name(), {1}},
42 {prim::kPrimBroadcastTo->name(), {1}},
43 {prim::kPrimReduceMean->name(), {1}},
44 {prim::kPrimSlice->name(), {1, 2}},
45 {prim::kPrimStridedSlice->name(), {1, 2, 3}},
46 {prim::kPrimOneHot->name(), {1}},
47 {prim::kPrimReduceFusion->name(), {1}},
48 {prim::kPrimConstantOfShape->name(), {0}},
49 {prim::kPrimGather->name(), {2}},
50 {prim::kPrimTupleGetItem->name(), {1}},
51 {prim::kPrimUnsortedSegmentSum->name(), {2}},
52 {prim::kPrimCumSum->name(), {1}}};
53 return op_idx_info_;
54 }
55
IsConstInput(const AnfNodePtr & node)56 bool ValueDependOpUtils::IsConstInput(const AnfNodePtr &node) {
57 auto prim = GetCNodePrimitive(node);
58 if (prim != nullptr) {
59 const auto &op_index_info = GetOpIndexInfo();
60 auto iter = op_index_info.find(prim->name());
61 if (iter != op_index_info.end()) {
62 auto inputs = node->cast<CNodePtr>()->inputs();
63 for (const auto &i : iter->second) {
64 if (i + 1 < inputs.size() && inputs[i + 1] != nullptr) {
65 auto input_node = inputs[i + 1];
66 ValuePtr value = nullptr;
67 if (input_node->isa<ValueNode>()) {
68 auto value_node = input_node->cast<ValueNodePtr>();
69 value = value_node->value();
70 } else if (input_node->isa<Parameter>()) {
71 auto parameter_node = input_node->cast<ParameterPtr>();
72 value = parameter_node->abstract()->BuildValue();
73 }
74 if (value == nullptr) {
75 return false;
76 }
77 if (value->isa<ValueAny>()) {
78 return false;
79 }
80 auto tensor = value->cast<tensor::TensorPtr>();
81 if (tensor != nullptr && tensor->data().const_data() == nullptr) {
82 return false;
83 }
84 }
85 }
86 }
87 }
88 return true;
89 }
90
AddConstInputToAttr(const CNodePtr & cnode,const HashSet<size_t> & input_idx)91 bool ValueDependOpUtils::AddConstInputToAttr(const CNodePtr &cnode, const HashSet<size_t> &input_idx) {
92 auto primitive = GetCNodePrimitive(cnode);
93 MS_EXCEPTION_IF_NULL(primitive);
94 primitive = primitive->Clone();
95 MS_EXCEPTION_IF_NULL(primitive);
96
97 const auto &op_name = primitive->name();
98 auto op_def = mindspore::ops::GetOpDef(op_name);
99 if (op_def == nullptr) {
100 MS_LOG(INFO) << op_name << " not found in op def.";
101 return false;
102 }
103 const auto &input_vec = op_def->args_;
104 auto inputs = cnode->inputs();
105 for (size_t i = 0; i < inputs.size() - 1; ++i) {
106 auto input_node = inputs[i + 1];
107 MS_EXCEPTION_IF_NULL(input_node);
108 if (input_idx.count(i) != 0) {
109 if (i >= input_vec.size()) {
110 MS_LOG(INFO) << "Index " << i << " is larger than input names size [" << input_vec.size() << "]";
111 return false;
112 }
113 ValuePtr value = nullptr;
114 if (input_node->isa<ValueNode>()) {
115 auto value_node = input_node->cast<ValueNodePtr>();
116 value = value_node->value();
117 } else if (input_node->isa<Parameter>()) {
118 auto parameter_node = input_node->cast<ParameterPtr>();
119 value = parameter_node->abstract()->BuildValue();
120 }
121 if (value == nullptr) {
122 MS_LOG(DEBUG) << input_vec[i].arg_name_ << "'s Value is null.";
123 return false;
124 }
125 if (value->isa<ValueAny>()) {
126 MS_LOG(DEBUG) << input_vec[i].arg_name_ << "'s Value is ValueAny.";
127 return false;
128 }
129 if (!value->isa<tensor::Tensor>()) {
130 primitive->set_attr(input_vec[i].arg_name_, value);
131 continue;
132 }
133 auto value_vector = CheckAndConvertUtils::CheckTensorIntValue(input_vec[i].arg_name_, value, primitive->name());
134 auto tensor = value->cast<tensor::TensorPtr>();
135 auto tensor_shape = tensor->shape_c();
136 if (tensor_shape.empty()) {
137 primitive->set_attr(input_vec[i].arg_name_, MakeValue(value_vector[0]));
138 } else {
139 primitive->set_attr(input_vec[i].arg_name_, MakeValue(value_vector));
140 }
141 }
142 }
143 cnode->set_input(0, std::make_shared<ValueNode>(primitive));
144 return true;
145 }
146
147 } // namespace mindspore::graphkernel
148