1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_TBE_JSON_CREATOR_H_ 18 #define MINDSPORE_TBE_JSON_CREATOR_H_ 19 #include <string> 20 #include <unordered_map> 21 #include <memory> 22 #include <map> 23 #include <utility> 24 #include <vector> 25 #include <nlohmann/json.hpp> 26 #include "ir/dtype.h" 27 #include "backend/kernel_compiler/kernel.h" 28 #include "backend/kernel_compiler/kernel_fusion.h" 29 #include "backend/kernel_compiler/oplib/oplib.h" 30 #include "backend/kernel_compiler/tbe/tbe_adapter.h" 31 32 namespace mindspore::kernel { 33 enum ATTR_DTYPE { 34 ATTR_INT8 = 0, 35 ATTR_UINT8 = 1, 36 ATTR_INT16 = 2, 37 ATTR_UINT16 = 3, 38 ATTR_INT32 = 4, 39 ATTR_UINT32 = 5, 40 ATTR_INT64 = 6, 41 ATTR_UINT64 = 7, 42 ATTR_FLOAT32 = 8, 43 ATTR_DOUBLE = 9, 44 ATTR_BOOL = 10, 45 ATTR_STR = 11, 46 ATTR_LIST_INT8 = 12, 47 ATTR_LIST_UINT8 = 13, 48 ATTR_LIST_INT16 = 14, 49 ATTR_LIST_UINT16 = 15, 50 ATTR_LIST_INT32 = 16, 51 ATTR_LIST_UINT32 = 17, 52 ATTR_LIST_INT64 = 18, 53 ATTR_LIST_UINT64 = 19, 54 ATTR_LIST_FLOAT32 = 20, 55 ATTR_LIST_DOUBLE = 21, 56 ATTR_LIST_BOOL = 22, 57 ATTR_LIST_STR = 23, 58 ATTR_LIST_LIST_INT64 = 24, 59 ATTR_LIST_LIST_FLOAT = 25, 60 61 // illegal type which can't be fused 62 ATTR_MAX, 63 }; 64 65 class TbeJsonCreator { 66 public: 67 TbeJsonCreator() = default; 68 virtual ~TbeJsonCreator() = default; GenJson(const AnfNodePtr & anf_node,nlohmann::json * kernel_json)69 virtual bool GenJson(const AnfNodePtr &anf_node, nlohmann::json *kernel_json) { return false; } GenJson(const FusionScopeInfo & fusion_scope_info,nlohmann::json * fusion_json)70 virtual bool GenJson(const FusionScopeInfo &fusion_scope_info, nlohmann::json *fusion_json) { return false; } GetJsonName()71 std::string GetJsonName() { return json_name_; } GetJsonHash()72 size_t GetJsonHash() { return json_hash_; } 73 74 protected: 75 bool GenComputeJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json); GenInputsJson(const AnfNodePtr & anf_node,nlohmann::json * compute_json)76 virtual bool GenInputsJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json) { return false; } GenOutputsJson(const AnfNodePtr & anf_node,nlohmann::json * compute_json)77 virtual bool GenOutputsJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json) { return false; } 78 void GenOutputDataDescJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json); 79 void GenComputeCommonJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json); GenOtherJson(const AnfNodePtr & anf_node,nlohmann::json * compute_json)80 virtual void GenOtherJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json) {} 81 void GenAttrsDescJson(const AnfNodePtr &anf_node, nlohmann::json *compute_json); 82 void GenAttrsJson(const AnfNodePtr &anf_node, const OpInfoPtr &op_info_ptr, nlohmann::json *attrs_json); 83 bool AttrsJsonPreProcessing(const AnfNodePtr &anf_node, std::vector<OpAttrPtr> *attrs_ptr, 84 nlohmann::json *attrs_json); 85 virtual bool AttrsJsonPostProcessing(const AnfNodePtr &anf_node, const OpInfoPtr &op_info_ptr, 86 nlohmann::json *attrs_json); 87 virtual void GenDescJson(const AnfNodePtr &anf_node, size_t node_out_idx, size_t desc_output_idx, 88 nlohmann::json *output_desc); 89 void GenDesJsonCommon(nlohmann::json *output_desc); 90 void GenInputConstValue(const AnfNodePtr &anf_node, size_t real_input_index, nlohmann::json *input_desc); 91 size_t GenJsonHash(nlohmann::json tbe_json); 92 void DeleteDescName(nlohmann::json *desc_json); 93 void AddOpNameForComputeNode(nlohmann::json *kernel_json); 94 void GenFusionOpName(nlohmann::json *kernel_json, std::string prefix = ""); 95 96 private: 97 std::string json_name_; 98 size_t json_hash_; 99 }; 100 101 } // namespace mindspore::kernel 102 #endif // MINDSPORE_TBE_JSON_CREATOR_H_ 103