1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_UTILS_CUSTOM_AOT_EXTRA_H 18 #define MINDSPORE_CCSRC_UTILS_CUSTOM_AOT_EXTRA_H 19 20 #include <string> 21 #include <vector> 22 #include "ir/anf.h" 23 #include "mindspore/ccsrc/include/common/utils/anfalgo.h" 24 25 namespace mindspore { 26 class AotKernelData { 27 public: 28 AotKernelData() = default; 29 virtual ~AotKernelData() = default; 30 }; 31 32 class AotExtra { 33 public: 34 AotExtra() = default; 35 virtual ~AotExtra() = default; 36 virtual bool HasAttr(std::string name) = 0; 37 38 template <typename T> Attr(std::string name)39 inline T Attr(std::string name) const { 40 MS_EXCEPTION_IF_CHECK_FAIL(name.length() > 0, "The input name is an empty string"); 41 return T(); 42 } 43 SetWorkSpace(const std::vector<size_t> & workspace)44 void SetWorkSpace(const std::vector<size_t> &workspace) { workspace_ = workspace; } WorkSpace()45 const std::vector<size_t> &WorkSpace() const { return workspace_; } 46 SetKernelData(AotKernelData * kernel_data)47 void SetKernelData(AotKernelData *kernel_data) { kernel_data_ = kernel_data; } KernelData()48 const AotKernelData *KernelData() const { return kernel_data_; } 49 DestructKernelData()50 void DestructKernelData() { 51 delete kernel_data_; 52 kernel_data_ = nullptr; 53 } 54 55 private: 56 virtual bool GetAttrBool(std::string name) = 0; 57 virtual int64_t GetAttrInt(std::string name) = 0; 58 virtual float GetAttrFloat(std::string name) = 0; 59 virtual std::string GetAttrStr(std::string name) = 0; 60 61 virtual std::vector<int64_t> GetAttrIntVec(std::string name) = 0; 62 virtual std::vector<float> GetAttrFloatVec(std::string name) = 0; 63 virtual std::vector<std::vector<int64_t>> GetAttrInt2DVec(std::string name) = 0; 64 virtual std::vector<std::vector<float>> GetAttrFloat2DVec(std::string name) = 0; 65 std::vector<size_t> workspace_; 66 67 AotKernelData *kernel_data_{nullptr}; 68 }; 69 70 class AotExtraImpl : public AotExtra { 71 public: AotExtraImpl()72 AotExtraImpl() : prim_(nullptr) {} 73 virtual ~AotExtraImpl() = default; SetKernelPrim(const PrimitivePtr & prim)74 void SetKernelPrim(const PrimitivePtr &prim) { prim_ = prim; } HasAttr(std::string name)75 bool HasAttr(std::string name) final { return prim_ != nullptr && prim_->HasAttr(name); } 76 77 private: GetAttrBool(std::string name)78 bool GetAttrBool(std::string name) { 79 MS_EXCEPTION_IF_NULL(prim_); 80 auto value = prim_->GetAttr(name); 81 if (value == nullptr) { 82 MS_LOG(EXCEPTION) << "For '" << prim_->ToString() << ", there is no attribute called " << name << "! "; 83 } 84 return GetValue<bool>(value); 85 } GetAttrInt(std::string name)86 int64_t GetAttrInt(std::string name) { 87 MS_EXCEPTION_IF_NULL(prim_); 88 auto value = prim_->GetAttr(name); 89 if (value == nullptr) { 90 MS_LOG(EXCEPTION) << "For '" << prim_->ToString() << ", there is no attribute called " << name << "! "; 91 } 92 return GetValue<int64_t>(value); 93 } GetAttrFloat(std::string name)94 float GetAttrFloat(std::string name) { 95 MS_EXCEPTION_IF_NULL(prim_); 96 auto value = prim_->GetAttr(name); 97 if (value == nullptr) { 98 MS_LOG(EXCEPTION) << "For '" << prim_->ToString() << ", there is no attribute called " << name << "! "; 99 } 100 return GetValue<float>(value); 101 } GetAttrStr(std::string name)102 std::string GetAttrStr(std::string name) { 103 MS_EXCEPTION_IF_NULL(prim_); 104 auto value = prim_->GetAttr(name); 105 if (value == nullptr) { 106 MS_LOG(EXCEPTION) << "For '" << prim_->ToString() << ", there is no attribute called " << name << "! "; 107 } 108 return GetValue<std::string>(value); 109 } 110 GetAttrIntVec(std::string name)111 std::vector<int64_t> GetAttrIntVec(std::string name) { 112 MS_EXCEPTION_IF_NULL(prim_); 113 auto value = prim_->GetAttr(name); 114 if (value == nullptr) { 115 MS_LOG(EXCEPTION) << "For '" << prim_->ToString() << ", there is no attribute called " << name << "! "; 116 } 117 return GetValue<std::vector<int64_t>>(value); 118 } GetAttrFloatVec(std::string name)119 std::vector<float> GetAttrFloatVec(std::string name) { 120 MS_EXCEPTION_IF_NULL(prim_); 121 auto value = prim_->GetAttr(name); 122 if (value == nullptr) { 123 MS_LOG(EXCEPTION) << "For '" << prim_->ToString() << ", there is no attribute called " << name << "! "; 124 } 125 return GetValue<std::vector<float>>(value); 126 } GetAttrInt2DVec(std::string name)127 std::vector<std::vector<int64_t>> GetAttrInt2DVec(std::string name) { 128 MS_EXCEPTION_IF_NULL(prim_); 129 auto value = prim_->GetAttr(name); 130 if (value == nullptr) { 131 MS_LOG(EXCEPTION) << "For '" << prim_->ToString() << ", there is no attribute called " << name << "! "; 132 } 133 return GetValue<std::vector<std::vector<int64_t>>>(value); 134 } GetAttrFloat2DVec(std::string name)135 std::vector<std::vector<float>> GetAttrFloat2DVec(std::string name) { 136 MS_EXCEPTION_IF_NULL(prim_); 137 auto value = prim_->GetAttr(name); 138 if (value == nullptr) { 139 MS_LOG(EXCEPTION) << "For '" << prim_->ToString() << ", there is no attribute called " << name << "! "; 140 } 141 return GetValue<std::vector<std::vector<float>>>(value); 142 } 143 PrimitivePtr prim_; 144 }; 145 } // namespace mindspore 146 #endif // MINDSPORE_CCSRC_UTILS_CUSTOM_AOT_EXTRA_H 147