• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_UTILS_CUSTOM_AOT_EXTRA_H
18 #define MINDSPORE_CCSRC_UTILS_CUSTOM_AOT_EXTRA_H
19 
20 #include <string>
21 #include <vector>
22 #include "ir/anf.h"
23 #include "mindspore/ccsrc/include/common/utils/anfalgo.h"
24 
25 namespace mindspore {
/**
 * Polymorphic base for data attached to an AotExtra instance.
 *
 * The class has no members of its own. Its virtual destructor lets an object
 * of a derived type be destroyed correctly through an AotKernelData pointer.
 */
class AotKernelData {
 public:
  virtual ~AotKernelData() = default;
  AotKernelData() = default;
};
31 
32 class AotExtra {
33  public:
34   AotExtra() = default;
35   virtual ~AotExtra() = default;
36   virtual bool HasAttr(std::string name) = 0;
37 
38   template <typename T>
Attr(std::string name)39   inline T Attr(std::string name) const {
40     MS_EXCEPTION_IF_CHECK_FAIL(name.length() > 0, "The input name is an empty string");
41     return T();
42   }
43 
SetWorkSpace(const std::vector<size_t> & workspace)44   void SetWorkSpace(const std::vector<size_t> &workspace) { workspace_ = workspace; }
WorkSpace()45   const std::vector<size_t> &WorkSpace() const { return workspace_; }
46 
SetKernelData(AotKernelData * kernel_data)47   void SetKernelData(AotKernelData *kernel_data) { kernel_data_ = kernel_data; }
KernelData()48   const AotKernelData *KernelData() const { return kernel_data_; }
49 
DestructKernelData()50   void DestructKernelData() {
51     delete kernel_data_;
52     kernel_data_ = nullptr;
53   }
54 
55  private:
56   virtual bool GetAttrBool(std::string name) = 0;
57   virtual int64_t GetAttrInt(std::string name) = 0;
58   virtual float GetAttrFloat(std::string name) = 0;
59   virtual std::string GetAttrStr(std::string name) = 0;
60 
61   virtual std::vector<int64_t> GetAttrIntVec(std::string name) = 0;
62   virtual std::vector<float> GetAttrFloatVec(std::string name) = 0;
63   virtual std::vector<std::vector<int64_t>> GetAttrInt2DVec(std::string name) = 0;
64   virtual std::vector<std::vector<float>> GetAttrFloat2DVec(std::string name) = 0;
65   std::vector<size_t> workspace_;
66 
67   AotKernelData *kernel_data_{nullptr};
68 };
69 
70 class AotExtraImpl : public AotExtra {
71  public:
AotExtraImpl()72   AotExtraImpl() : prim_(nullptr) {}
73   virtual ~AotExtraImpl() = default;
SetKernelPrim(const PrimitivePtr & prim)74   void SetKernelPrim(const PrimitivePtr &prim) { prim_ = prim; }
HasAttr(std::string name)75   bool HasAttr(std::string name) final { return prim_ != nullptr && prim_->HasAttr(name); }
76 
77  private:
GetAttrBool(std::string name)78   bool GetAttrBool(std::string name) {
79     MS_EXCEPTION_IF_NULL(prim_);
80     auto value = prim_->GetAttr(name);
81     if (value == nullptr) {
82       MS_LOG(EXCEPTION) << "For '" << prim_->ToString() << ", there is no attribute called " << name << "! ";
83     }
84     return GetValue<bool>(value);
85   }
GetAttrInt(std::string name)86   int64_t GetAttrInt(std::string name) {
87     MS_EXCEPTION_IF_NULL(prim_);
88     auto value = prim_->GetAttr(name);
89     if (value == nullptr) {
90       MS_LOG(EXCEPTION) << "For '" << prim_->ToString() << ", there is no attribute called " << name << "! ";
91     }
92     return GetValue<int64_t>(value);
93   }
GetAttrFloat(std::string name)94   float GetAttrFloat(std::string name) {
95     MS_EXCEPTION_IF_NULL(prim_);
96     auto value = prim_->GetAttr(name);
97     if (value == nullptr) {
98       MS_LOG(EXCEPTION) << "For '" << prim_->ToString() << ", there is no attribute called " << name << "! ";
99     }
100     return GetValue<float>(value);
101   }
GetAttrStr(std::string name)102   std::string GetAttrStr(std::string name) {
103     MS_EXCEPTION_IF_NULL(prim_);
104     auto value = prim_->GetAttr(name);
105     if (value == nullptr) {
106       MS_LOG(EXCEPTION) << "For '" << prim_->ToString() << ", there is no attribute called " << name << "! ";
107     }
108     return GetValue<std::string>(value);
109   }
110 
GetAttrIntVec(std::string name)111   std::vector<int64_t> GetAttrIntVec(std::string name) {
112     MS_EXCEPTION_IF_NULL(prim_);
113     auto value = prim_->GetAttr(name);
114     if (value == nullptr) {
115       MS_LOG(EXCEPTION) << "For '" << prim_->ToString() << ", there is no attribute called " << name << "! ";
116     }
117     return GetValue<std::vector<int64_t>>(value);
118   }
GetAttrFloatVec(std::string name)119   std::vector<float> GetAttrFloatVec(std::string name) {
120     MS_EXCEPTION_IF_NULL(prim_);
121     auto value = prim_->GetAttr(name);
122     if (value == nullptr) {
123       MS_LOG(EXCEPTION) << "For '" << prim_->ToString() << ", there is no attribute called " << name << "! ";
124     }
125     return GetValue<std::vector<float>>(value);
126   }
GetAttrInt2DVec(std::string name)127   std::vector<std::vector<int64_t>> GetAttrInt2DVec(std::string name) {
128     MS_EXCEPTION_IF_NULL(prim_);
129     auto value = prim_->GetAttr(name);
130     if (value == nullptr) {
131       MS_LOG(EXCEPTION) << "For '" << prim_->ToString() << ", there is no attribute called " << name << "! ";
132     }
133     return GetValue<std::vector<std::vector<int64_t>>>(value);
134   }
GetAttrFloat2DVec(std::string name)135   std::vector<std::vector<float>> GetAttrFloat2DVec(std::string name) {
136     MS_EXCEPTION_IF_NULL(prim_);
137     auto value = prim_->GetAttr(name);
138     if (value == nullptr) {
139       MS_LOG(EXCEPTION) << "For '" << prim_->ToString() << ", there is no attribute called " << name << "! ";
140     }
141     return GetValue<std::vector<std::vector<float>>>(value);
142   }
143   PrimitivePtr prim_;
144 };
145 }  // namespace mindspore
146 #endif  // MINDSPORE_CCSRC_UTILS_CUSTOM_AOT_EXTRA_H
147