• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_TBE_JSON_CREATOR_H_
18 #define MINDSPORE_TBE_JSON_CREATOR_H_
19 #include <string>
20 #include <unordered_map>
21 #include <memory>
22 #include <map>
23 #include <utility>
24 #include <vector>
25 #include <nlohmann/json.hpp>
26 #include "ir/dtype.h"
27 #include "backend/kernel_compiler/kernel.h"
28 #include "backend/kernel_compiler/kernel_fusion.h"
29 #include "backend/kernel_compiler/oplib/oplib.h"
30 #include "backend/kernel_compiler/tbe/tbe_adapter.h"
31 
32 namespace mindspore::kernel {
// Attribute data-type codes. Each enumerator has an explicit numeric value so
// that the code for a given type stays stable when entries are reordered or added.
enum ATTR_DTYPE {
  // Scalar attribute types.
  ATTR_INT8 = 0,
  ATTR_UINT8 = 1,
  ATTR_INT16 = 2,
  ATTR_UINT16 = 3,
  ATTR_INT32 = 4,
  ATTR_UINT32 = 5,
  ATTR_INT64 = 6,
  ATTR_UINT64 = 7,
  ATTR_FLOAT32 = 8,
  ATTR_DOUBLE = 9,
  ATTR_BOOL = 10,
  ATTR_STR = 11,

  // One-dimensional list attribute types, in the same order as the scalars above.
  ATTR_LIST_INT8 = 12,
  ATTR_LIST_UINT8 = 13,
  ATTR_LIST_INT16 = 14,
  ATTR_LIST_UINT16 = 15,
  ATTR_LIST_INT32 = 16,
  ATTR_LIST_UINT32 = 17,
  ATTR_LIST_INT64 = 18,
  ATTR_LIST_UINT64 = 19,
  ATTR_LIST_FLOAT32 = 20,
  ATTR_LIST_DOUBLE = 21,
  ATTR_LIST_BOOL = 22,
  ATTR_LIST_STR = 23,

  // Nested (list-of-list) attribute types.
  ATTR_LIST_LIST_INT64 = 24,
  ATTR_LIST_LIST_FLOAT = 25,

  // Sentinel / illegal type which can't be fused. It is left implicit so it
  // always stays one past the last real entry.
  ATTR_MAX,
};
64 
65 class TbeJsonCreator {
66  public:
67   TbeJsonCreator() = default;
68   virtual ~TbeJsonCreator() = default;
GenJson(const AnfNodePtr & anf_node,nlohmann::json * kernel_json)69   virtual bool GenJson(const AnfNodePtr &anf_node, nlohmann::json *kernel_json) { return false; }
GenJson(const FusionScopeInfo & fusion_scope_info,nlohmann::json * fusion_json)70   virtual bool GenJson(const FusionScopeInfo &fusion_scope_info, nlohmann::json *fusion_json) { return false; }
GetJsonName()71   std::string GetJsonName() { return json_name_; }
GetJsonHash()72   size_t GetJsonHash() { return json_hash_; }
73 
74  protected:
75   bool GenComputeJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json);
GenInputsJson(const AnfNodePtr & anf_node,nlohmann::json * compute_json)76   virtual bool GenInputsJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json) { return false; }
GenOutputsJson(const AnfNodePtr & anf_node,nlohmann::json * compute_json)77   virtual bool GenOutputsJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json) { return false; }
78   void GenOutputDataDescJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json);
79   void GenComputeCommonJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json);
GenOtherJson(const AnfNodePtr & anf_node,nlohmann::json * compute_json)80   virtual void GenOtherJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json) {}
81   void GenAttrsDescJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json);
82   void GenAttrsJson(const AnfNodePtr &anf_node, const OpInfoPtr &op_info_ptr, nlohmann::json *attrs_json);
83   bool AttrsJsonPreProcessing(const AnfNodePtr &anf_node, std::vector<OpAttrPtr> *attrs_ptr,
84                               nlohmann::json *attrs_json);
85   virtual bool AttrsJsonPostProcessing(const AnfNodePtr &anf_node, const OpInfoPtr &op_info_ptr,
86                                        nlohmann::json *attrs_json);
87   virtual void GenDescJson(const AnfNodePtr &anf_node, size_t node_out_idx, size_t desc_output_idx,
88                            nlohmann::json *output_desc);
89   void GenDesJsonCommon(nlohmann::json *output_desc);
90   void GenInputConstValue(const AnfNodePtr &anf_node, size_t real_input_index, nlohmann::json *input_desc);
91   size_t GenJsonHash(nlohmann::json tbe_json);
92   void DeleteDescName(nlohmann::json *desc_json);
93   void AddOpNameForComputeNode(nlohmann::json *kernel_json);
94   void GenFusionOpName(nlohmann::json *kernel_json, std::string prefix = "");
95 
96  private:
97   std::string json_name_;
98   size_t json_hash_;
99 };
100 
101 }  // namespace mindspore::kernel
102 #endif  // MINDSPORE_TBE_JSON_CREATOR_H_
103