• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
18 
19 #include <map>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <tuple>
24 #include <utility>
25 #include <vector>
26 #include "utils/hash_set.h"
27 #include "ir/anf.h"
28 #include "ir/func_graph.h"
29 #include "ir/primitive.h"
30 #include "include/backend/anf_runtime_algorithm.h"
31 #include "include/common/utils/anfalgo.h"
32 #include "include/backend/kernel_graph.h"
33 #include "kernel/graph_kernel/graph_kernel_json_generator.h"
34 #include <nlohmann/json.hpp>
35 #include "backend/common/graph_kernel/model/lite_graph.h"
36 
37 namespace mindspore::graphkernel {
// String constants shared through this header. Their users are not visible here, so the grouping notes
// below are inferred from the names and values only — TODO confirm against the call sites.
//
// Presumably attribute names attached to nodes (feature-map markers) — verify.
38 constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
39 constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
// A dotted name that looks like a Python module path, followed by what appear to be names of
// functions looked up in that module — NOTE(review): confirm where these are resolved.
40 constexpr auto kGraphKernelModule = "mindspore._extends.graph_kernel";
41 constexpr auto kGraphKernelEstimateOps = "estimate_ops";
42 constexpr auto kGraphKernelGetNodeCalAmount = "estimate_calculation_amount";
43 constexpr auto kGraphKernelSplitFunc = "split_with_json";
44 constexpr auto kGetGraphKernelOpExpander = "get_op_expander";
45 constexpr auto kGetGraphKernelExpanderOpList = "get_expander_op_list";
// Presumably keys used in a JSON document describing graphs (per the kJsonKey prefix) — verify.
46 constexpr auto kJsonKeyMultiGraph = "multi_graph";
47 constexpr auto kJsonKeyGraphDesc = "graph_desc";
48 constexpr auto kJsonKeyGraphMode = "graph_mode";
49 
// Plain aggregate bundling a format string, a shape and a type. Within this header it is the parameter
// type of CreateCNode (`out_info`) and CreateTensorValueNode (`info`).
50 struct DataInfo {
51   std::string format{kOpFormat_DEFAULT};  // defaults to kOpFormat_DEFAULT, which is declared outside this header
52   ShapeVector shape{1};  // brace-initialized with 1; presumably the one-element shape [1] if ShapeVector is a std::vector — confirm
53   TypePtr type{nullptr};  // defaults to nullptr; no type is set unless the caller assigns one
54 };
55 
// Two overloads returning a kernel::KernelBuildInfoPtr from four lists: input formats, input types,
// output formats and output types. The second overload additionally takes a kernel::Processor.
// NOTE(review): how the first overload chooses a processor is not visible in this header — confirm in the
// implementation file. Whether the format and type lists must have matching lengths is also not shown here.
56 kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::string> &inputs_format,
57                                                       const std::vector<TypeId> &inputs_type,
58                                                       const std::vector<std::string> &output_formats,
59                                                       const std::vector<TypeId> &output_types);
60 kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::string> &inputs_format,
61                                                       const std::vector<TypeId> &inputs_type,
62                                                       const std::vector<std::string> &output_formats,
63                                                       const std::vector<TypeId> &output_types,
64                                                       const kernel::Processor &processor);
// Conversions between ANF nodes and a JSON description.
// Each AnfToJsonDesc overload takes a node list (or, in the third overload, a vector of node lists), a
// DumpOption and a non-const nlohmann::json pointer `op_desc`, and returns bool.
// NOTE(review): `op_desc` is presumably the output and the bool presumably signals success — confirm.
// The second overload also takes `address_node_map`, a non-const pointer to a map from std::string to
// AnfNodePtr; what the string keys contain is not visible here.
// JsonDescToAnf goes the other direction: it takes a string `json_desc` and returns a FuncGraphPtr.
65 bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc);
66 bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
67                    std::map<std::string, AnfNodePtr> *address_node_map);
68 bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc);
69 FuncGraphPtr JsonDescToAnf(const std::string &json_desc);
70 
// Per-node query helpers: each takes a single AnfNodePtr and returns, respectively, a std::string,
// a TypePtr, a ShapeVector, a ShapeVector and a std::vector<int64_t>.
// Judging by the names these are the node's format, type, shape, device shape and reduce axes; where each
// value is read from (and how GetShape differs from GetDeviceShape) is defined in the implementation file
// and is not visible here — TODO confirm before relying on it.
71 std::string GetFormat(const AnfNodePtr &node);
72 TypePtr GetType(const AnfNodePtr &node);
73 ShapeVector GetShape(const AnfNodePtr &node);
74 ShapeVector GetDeviceShape(const AnfNodePtr &node);
75 std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node);
76 
// CreateCNode takes a list of input nodes, a func graph and a DataInfo (`out_info`), and returns a CNodePtr.
// `use_fake_abstract` defaults to false, so existing three-argument calls keep that behavior; what a
// "fake abstract" is cannot be determined from this header — NOTE(review): confirm in the implementation.
77 CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info,
78                      bool use_fake_abstract = false);
// SetNodeAttrSafely takes an attribute key, a value and a node, and returns nothing.
// What "safely" protects against is not visible here — TODO confirm.
79 void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);
80 
// CreateTensorValueNode takes a DataInfo, an untyped pointer `value_ptr` and a size_t `data_length`, and
// returns a ValueNodePtr. NOTE(review): `data_length` is presumably the size of the buffer behind
// `value_ptr` (unit not shown here), and whether the data is copied or referenced is not visible — confirm.
81 ValueNodePtr CreateTensorValueNode(const DataInfo &info, void *value_ptr, size_t data_length);
// GetOutputAbstract takes a node and an output index and returns an AbstractBasePtr.
82 AbstractBasePtr GetOutputAbstract(const AnfNodePtr &node, size_t output_idx);
// Boolean predicates on a single node; the exact conditions they test are defined in the implementation file.
83 bool IsBufferStitchNode(const AnfNodePtr &node);
84 bool CheckDefaultFormat(const AnfNodePtr &node);
85 }  // namespace mindspore::graphkernel
86 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
87