• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_GRAPH_KERNEL_GRAPH_KERNEL_JSON_GENERATOR_H_
18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_GRAPH_KERNEL_GRAPH_KERNEL_JSON_GENERATOR_H_
19 #include <map>
20 #include <memory>
21 #include <string>
22 #include <unordered_map>
23 #include <utility>
24 #include <vector>
25 #include "nlohmann/json.hpp"
26 #include "kernel/oplib/opinfo.h"
27 #include "backend/common/graph_kernel/core/graph_kernel_callback.h"
28 #include "include/common/utils/convert_utils.h"
29 #include "mindspore/core/symbolic_shape/symbol_engine.h"
30 
31 namespace mindspore::graphkernel {
32 using kernel::OpAttrPtr;
33 using kernel::OpInfoPtr;
34 
/// Switches that control how GraphKernelJsonGenerator dumps kernel json.
/// Every switch defaults to off; what each one changes is decided by the generator implementation (not in this header).
struct DumpOption {
  // NOTE(review): the per-field notes below are inferred from the field names and from related members of
  // GraphKernelJsonGenerator only; confirm against the implementation.
  bool is_before_select_kernel{false};      // presumably: json is generated before kernel selection has run
  bool save_ptr_address{false};             // presumably: record node addresses (cf. SaveNodeAddress)
  bool extract_opinfo_from_anfnode{false};  // presumably: derive op info from the node itself (cf. ExtractOpInfo)
  bool get_target_info{false};              // presumably: attach target info to the json (cf. TargetInfoSetter)
  bool gen_kernel_name_only{false};         // presumably: only compute the kernel name (cf. GenKernelName)
};
43 
44 class TargetInfoSetter {
45  public:
Set(nlohmann::json * kernel_info)46   static void Set(nlohmann::json *kernel_info) {
47     static std::unique_ptr<TargetInfoSetter> instance = nullptr;
48     if (instance == nullptr) {
49       instance = std::make_unique<TargetInfoSetter>();
50       instance->GetTargetInfo();
51     }
52     instance->SetTargetInfo(kernel_info);
53   }
54 
55  private:
56   void GetTargetInfo();
57   void SetTargetInfo(nlohmann::json *kernel_info) const;
58   nlohmann::json target_info_;
59   bool has_info_{true};
60 };
61 
62 class GraphKernelJsonGenerator {
63  public:
GraphKernelJsonGenerator()64   GraphKernelJsonGenerator() : cb_(Callback::Instance()) {}
GraphKernelJsonGenerator(DumpOption dump_option)65   explicit GraphKernelJsonGenerator(DumpOption dump_option)
66       : dump_option_(std::move(dump_option)), cb_(Callback::Instance()) {}
GraphKernelJsonGenerator(DumpOption dump_option,const CallbackPtr & cb)67   GraphKernelJsonGenerator(DumpOption dump_option, const CallbackPtr &cb)
68       : dump_option_(std::move(dump_option)), cb_(cb) {}
69   ~GraphKernelJsonGenerator() = default;
70 
71   bool CollectJson(const AnfNodePtr &anf_node, nlohmann::json *kernel_json);
72   bool CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
73                         const std::vector<AnfNodePtr> &output_list, nlohmann::json *kernel_json,
74                         const bool is_akg_cc = false);
75   bool CollectJson(const AnfNodePtr &anf_node);
76   bool CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
77                         const std::vector<AnfNodePtr> &output_list, const bool is_akg_cc = false);
78   bool CollectFusedJsonWithSingleKernel(const CNodePtr &c_node);
79 
kernel_name()80   std::string kernel_name() const { return kernel_name_; }
kernel_json()81   nlohmann::json kernel_json() const { return kernel_json_; }
kernel_json_str()82   std::string kernel_json_str() const { return kernel_json_.dump(); }
input_size_list()83   const std::vector<size_t> &input_size_list() const { return input_size_list_; }
output_size_list()84   const std::vector<size_t> &output_size_list() const { return output_size_list_; }
address_node_map()85   std::map<std::string, AnfNodePtr> address_node_map() const { return address_node_map_; }
symbol_engine()86   const SymbolEnginePtr &symbol_engine() const { return symbol_engine_; }
set_symbol_engine(const SymbolEnginePtr & symbol_engine)87   void set_symbol_engine(const SymbolEnginePtr &symbol_engine) { symbol_engine_ = symbol_engine; }
88 
89  private:
90   bool GenerateSingleKernelJson(const AnfNodePtr &anf_node, nlohmann::json *node_json);
91   bool CreateInputDescJson(const AnfNodePtr &anf_node, const OpInfoPtr &op_info, nlohmann::json *inputs_json);
92   bool CreateOutputDescJson(const AnfNodePtr &anf_node, const OpInfoPtr &op_info, nlohmann::json *outputs_json);
93   void GetAttrJson(const AnfNodePtr &anf_node, const std::vector<int64_t> &dyn_input_sizes, const OpAttrPtr &op_attr,
94                    nlohmann::json *attr_json, const ValuePtr &attr_value);
95   bool CreateAttrDescJson(const AnfNodePtr &anf_node, const OpInfoPtr &op_info, nlohmann::json *attrs_json);
96   void GenStitchJson(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map,
97                      nlohmann::json *kernel_json) const;
98   void GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *input_size,
99                  std::vector<size_t> *output_size) const;
100   bool GenSingleJsons(const std::vector<AnfNodePtr> &anf_nodes, std::map<AnfNodePtr, nlohmann::json> *node_json_map);
101   void UpdateTensorName(const std::vector<AnfNodePtr> &anf_nodes,
102                         std::map<AnfNodePtr, nlohmann::json> *node_json_map) const;
103   nlohmann::json CreateInputsJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
104                                   const std::map<AnfNodePtr, nlohmann::json> &node_json_map);
105   nlohmann::json CreateOutputsJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
106                                    const std::vector<AnfNodePtr> &output_list, const nlohmann::json &inputs_json,
107                                    const std::map<AnfNodePtr, nlohmann::json> &node_json_map);
108   size_t GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx);
109   size_t GetOutputTensorIdxInc();
110   void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair<size_t, size_t> &position,
111                      nlohmann::json *node_json) const;
112   std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag,
113                             const std::pair<size_t, size_t> &position) const;
114   void SaveNodeAddress(const AnfNodePtr &anf_node, nlohmann::json *node_json);
115   OpInfoPtr ExtractOpInfo(const AnfNodePtr &anf_node) const;
116   void GenParallelJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
117                        const std::vector<AnfNodePtr> &output_list,
118                        const std::map<AnfNodePtr, nlohmann::json> &node_json_map, nlohmann::json *kernel_json) const;
119   bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, ShapeVector *input_shape,
120                            nlohmann::json *node_json) const;
121   size_t GetTensorSize(const nlohmann::json &node_json) const;
122   std::string GetProcessorByTarget() const;
123   size_t GenHashId(const std::string &info) const;
124   void GenKernelName(const FuncGraphPtr &fg, size_t hash_id, nlohmann::json *kernel_json);
125   void SaveShape(const AnfNodePtr &node, nlohmann::json *kernel_json, const ShapeVector &shape);
126   std::vector<std::string> QuerySymbolicShapeStr(const AnfNodePtr &node);
127 
128   DumpOption dump_option_;
129   std::string kernel_name_;
130   std::string all_ops_name_;
131   std::unordered_map<AnfNodePtr, size_t> input_tensor_idx_;
132   size_t output_tensor_idx_{0};
133   nlohmann::json kernel_json_;
134   std::vector<size_t> input_size_list_;
135   std::vector<size_t> output_size_list_;
136   std::map<std::string, AnfNodePtr> address_node_map_;
137   SymbolEnginePtr symbol_engine_;
138   std::unordered_map<std::string, std::string> symbol_calc_exprs_;
139   bool is_basic_op_{false};
140   CallbackPtr cb_;
141 };
142 }  // namespace mindspore::graphkernel
143 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_GRAPH_KERNEL_GRAPH_KERNEL_JSON_GENERATOR_H_
144