/**
 * Copyright 2022-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/common/graph_kernel/core/value_depend_op_utils.h"

#include <memory>
#include <vector>

#include "base/base.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/op_def.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "ops/primitive_c.h"
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"

namespace mindspore::graphkernel {
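// Maps each value-depend operator name to the indexes of the inputs whose values must be known at
// compile time, e.g. the target shape of Reshape, the axes of the Reduce ops, and the
// begin/end/strides of StridedSlice.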
const std::unordered_map<std::string, HashSet<size_t>> &ValueDependOpUtils::GetOpIndexInfo() {
  static const std::unordered_map<std::string, HashSet<size_t>> op_idx_info_ = {
    {prim::kPrimReshape->name(), {1}},
    {prim::kPrimReduceMax->name(), {1}},
    {prim::kPrimExpandDims->name(), {1}},
    {prim::kPrimReduceMin->name(), {1}},
    {prim::kPrimReduceSum->name(), {1}},
    {prim::kPrimTranspose->name(), {1}},
    {prim::kPrimTile->name(), {1}},
    {prim::kPrimBroadcastTo->name(), {1}},
    {prim::kPrimReduceMean->name(), {1}},
    {prim::kPrimSlice->name(), {1, 2}},
    {prim::kPrimStridedSlice->name(), {1, 2, 3}},
    {prim::kPrimOneHot->name(), {1}},
    {prim::kPrimReduceFusion->name(), {1}},
    {prim::kPrimConstantOfShape->name(), {0}},
    {prim::kPrimGather->name(), {2}},
    {prim::kPrimTupleGetItem->name(), {1}},
    {prim::kPrimUnsortedSegmentSum->name(), {2}},
    {prim::kPrimCumSum->name(), {1}}};
  return op_idx_info_;
}

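// Returns true when the node is not a value-depend op, or when every value-depend input carries a
// concrete constant value (a ValueNode, or a Parameter whose abstract builds a known value);
// returns false when any such value is unknown.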
bool ValueDependOpUtils::IsConstInput(const AnfNodePtr &node) {
  auto prim = GetCNodePrimitive(node);
  if (prim != nullptr) {
    const auto &op_index_info = GetOpIndexInfo();
    auto iter = op_index_info.find(prim->name());
    if (iter != op_index_info.end()) {
      auto inputs = node->cast<CNodePtr>()->inputs();
      for (const auto &i : iter->second) {
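        // CNode inputs are offset by one: inputs[0] holds the primitive, so data input i is at inputs[i + 1].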
        if (i + 1 < inputs.size() && inputs[i + 1] != nullptr) {
          auto input_node = inputs[i + 1];
          ValuePtr value = nullptr;
          if (input_node->isa<ValueNode>()) {
            auto value_node = input_node->cast<ValueNodePtr>();
            value = value_node->value();
          } else if (input_node->isa<Parameter>()) {
            auto parameter_node = input_node->cast<ParameterPtr>();
            value = parameter_node->abstract()->BuildValue();
          }
          if (value == nullptr) {
            return false;
          }
          if (value->isa<ValueAny>()) {
            return false;
          }
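          // A tensor value whose data has not been materialized (no const data) cannot be treated as a constant.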
          auto tensor = value->cast<tensor::TensorPtr>();
          if (tensor != nullptr && tensor->data().const_data() == nullptr) {
            return false;
          }
        }
      }
    }
  }
  return true;
}

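// Converts the const values of the given input indexes into attributes on a cloned primitive,
// using the argument names from the op definition, then installs the clone on the node.
// Returns false if the op has no op def or any required value is not a known constant.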
bool ValueDependOpUtils::AddConstInputToAttr(const CNodePtr &cnode, const HashSet<size_t> &input_idx) {
  auto primitive = GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(primitive);
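  // Work on a clone so the attribute changes below do not mutate the shared primitive instance.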
  primitive = primitive->Clone();
  MS_EXCEPTION_IF_NULL(primitive);

  const auto &op_name = primitive->name();
  auto op_def = mindspore::ops::GetOpDef(op_name);
  if (op_def == nullptr) {
    MS_LOG(INFO) << op_name << " not found in op def.";
    return false;
  }
  const auto &input_vec = op_def->args_;
  auto inputs = cnode->inputs();
  for (size_t i = 0; i < inputs.size() - 1; ++i) {
    auto input_node = inputs[i + 1];
    MS_EXCEPTION_IF_NULL(input_node);
    if (input_idx.count(i) != 0) {
      if (i >= input_vec.size()) {
        MS_LOG(INFO) << "Index " << i << " is larger than input names size [" << input_vec.size() << "]";
        return false;
      }
      ValuePtr value = nullptr;
      if (input_node->isa<ValueNode>()) {
        auto value_node = input_node->cast<ValueNodePtr>();
        value = value_node->value();
      } else if (input_node->isa<Parameter>()) {
        auto parameter_node = input_node->cast<ParameterPtr>();
        value = parameter_node->abstract()->BuildValue();
      }
      if (value == nullptr) {
        MS_LOG(DEBUG) << input_vec[i].arg_name_ << "'s Value is null.";
        return false;
      }
      if (value->isa<ValueAny>()) {
        MS_LOG(DEBUG) << input_vec[i].arg_name_ << "'s Value is ValueAny.";
        return false;
      }
      if (!value->isa<tensor::Tensor>()) {
        primitive->set_attr(input_vec[i].arg_name_, value);
        continue;
      }
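      // Tensor-valued const inputs are converted to integers: a 0-d tensor becomes a scalar int
      // attribute, any other tensor becomes an int vector attribute.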
      auto value_vector = CheckAndConvertUtils::CheckTensorIntValue(input_vec[i].arg_name_, value, primitive->name());
      auto tensor = value->cast<tensor::TensorPtr>();
      auto tensor_shape = tensor->shape_c();
      if (tensor_shape.empty()) {
        primitive->set_attr(input_vec[i].arg_name_, MakeValue(value_vector[0]));
      } else {
        primitive->set_attr(input_vec[i].arg_name_, MakeValue(value_vector));
      }
    }
  }
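  // Install the cloned primitive, which now carries the const inputs as attributes, back on the node.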
  cnode->set_input(0, std::make_shared<ValueNode>(primitive));
  return true;
}

}  // namespace mindspore::graphkernel