/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_NODE_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_NODE_INFO_H_

#include <string>
#include <vector>
#include <memory>
#include <utility>
#include "utils/hash_map.h"
#include "utils/hash_set.h"
#include "base/base.h"
#include "ir/anf.h"
#include "frontend/parallel/ops_info/operator_info.h"

namespace mindspore {
namespace parallel {
using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>;
std::string ParameterName(const AnfNodePtr &node_ptr);

bool ParameterRequireGrad(const AnfNodePtr &node_ptr);

size_t GetLengthOfDataType(const TypePtr &type);

std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node);

std::string ExtractInputParameterNameByNode(const CNodePtr &node);

std::vector<size_t> ExtractInputElementLength(const CNodePtr &node, std::vector<AnfNodePtr> node_inputs);

std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node);

std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node);

std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph);

bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name);

bool FindReshape(const CNodePtr &cnode, mindspore::HashSet<std::string> *op_cache);

bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, bool *is_prev_param,
                                 int64_t *out_index, size_t curr_depth);

void FindReshapeNextNodeStraCosts(const CNodePtr &cnode,
                                  std::vector<std::pair<OperatorInfoPtr, int64_t>> *next_ops_index,
                                  bool *is_next_reshape, size_t curr_depth);

void SetUserAttrs(const mindspore::HashMap<std::string, ValuePtr> &origin_prim_attrs, const PrimitivePtr &self_prim);

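// Converts the value sequence (ValueSequeue) held by input_value into a vector of int64_t;
// the returned Status reports whether the conversion succeeded.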
Status TransValueSequeueToVector(const ValuePtr &input_value, std::vector<int64_t> *input);

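// Builds a value sequence of type T from a vector of int64_t; T must derive from
// ValueSequeue (enforced via std::enable_if), e.g. a tuple- or list-like value type.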
template <typename T>
std::shared_ptr<typename std::enable_if<std::is_base_of<ValueSequeue, T>::value, T>::type> TransVectorToValueSequeue(
  const std::vector<int64_t> &input) {
  std::vector<ValuePtr> elements;
  for (auto dim : input) {
    ValuePtr value_dim = MakeValue<int64_t>(dim);
    elements.push_back(value_dim);
  }
  std::shared_ptr<T> seq_value = std::make_shared<T>(elements);
  return seq_value;
}

const AnfNodePtr RealInputNode(const CNodePtr cnode, size_t index);
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_NODE_INFO_H_