• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include <cstdio>
17 #include <unistd.h>
18 #include <sys/stat.h>
19 #include <iomanip>
20 #include <algorithm>
21 #include "mdgenerator.h"
22 
23 namespace MDGen {
24 std::string MDCodeGen::targetArchName = "";
25 
EmitCheckPtr(std::ofstream & outputFile,const std::string & emitName,const std::string & name,const std::string & ptrType) const26 void MDCodeGen::EmitCheckPtr(std::ofstream &outputFile, const std::string &emitName, const std::string &name,
27                              const std::string &ptrType) const
28 {
29     outputFile << "if(" << emitName << " == nullptr) {\n"
30                << "  maple::LogInfo::MapleLogger(maple::kLlErr) << \"" << ptrType << " allocation for " << name
31                << " failed.\" << std::endl;\n"
32                << "}\n"
33                << "DEBUG_ASSERT(" << emitName << ", \"" << ptrType << " allocation for " << name << " failed.\");\n"
34                << "\n";
35 }
36 
EmitFileHead(std::ofstream & outputFile,const std::string & headInfo) const37 void MDCodeGen::EmitFileHead(std::ofstream &outputFile, const std::string &headInfo) const
38 {
39     outputFile << "/* " << targetArchName << " " << headInfo << " definition : */\n";
40 }
41 
GetSpecificClass(const std::string & className)42 MDClass MDCodeGen::GetSpecificClass(const std::string &className)
43 {
44     unsigned int classIdx = curKeeper.GetStrInTable(className).idx;
45     CHECK_FATAL(classIdx != UINT_MAX, "Load Class Failed!");
46     return curKeeper.GetOneMDClass(classIdx);
47 }
48 
GetArchName()49 const std::string &SchedInfoGen::GetArchName()
50 {
51     MDClass archClass = GetSpecificClass("ArchitectureName");
52     const MDObject &archObj = archClass.GetOneMDObject(0);
53     auto *archStrEle = static_cast<const StringElement *>(archObj.GetOneMDElement(0));
54     return curKeeper.GetStrByIdx(archStrEle->GetContent());
55 }
56 
EmitArchDef()57 void SchedInfoGen::EmitArchDef()
58 {
59     MDClass parallelClass = GetSpecificClass("Parallelism");
60     CHECK_FATAL(parallelClass.GetMDObjectSize() > 0, "specific class failed, maybe illegal input");
61     const MDObject &paralleObj = parallelClass.GetOneMDObject(0);
62     auto *parallelEle = static_cast<const IntElement *>(paralleObj.GetOneMDElement(0));
63     outFile.open(GetOFileDir() + "/mplad_arch_define.def", std::ios::out);
64     EmitFileHead(outFile, "Architecture");
65     outFile << "SetMaxParallelism(" << parallelEle->GetContent() << ");\n";
66     outFile.close();
67 }
68 
EmitUnitIdDef()69 void SchedInfoGen::EmitUnitIdDef()
70 {
71     MDClass unitClass = GetSpecificClass("Unit");
72     outFile.open(GetOFileDir() + "/mplad_unit_id.def", std::ios::out);
73     CHECK_FATAL(outFile.is_open(), "Failed to open output file: %s/mplad_unit_id.def", GetOFileDir().c_str());
74     EmitFileHead(outFile, "function unit ID");
75     for (auto unitIdx : unitClass.GetchildObjNames()) {
76         outFile << "  " << curKeeper.GetStrByIdx(unitIdx) << ",\n";
77     }
78     outFile.close();
79 }
80 
EmitUnitNameDef()81 void SchedInfoGen::EmitUnitNameDef()
82 {
83     MDClass unitClass = GetSpecificClass("Unit");
84     outFile.open(GetOFileDir() + "/mplad_unit_name.def", std::ios::out);
85     CHECK_FATAL(outFile.is_open(), "Failed to open output file: %s/mplad_unit_name.def", GetOFileDir().c_str());
86     EmitFileHead(outFile, "function unit name");
87     for (auto unitIdx : unitClass.GetchildObjNames()) {
88         std::string unitPureName = curKeeper.GetStrByIdx(unitIdx);
89         std::string unitPrefix = "kUnitId";
90         if (unitPrefix.length() < unitPureName.length()) {
91             unitPureName = unitPureName.substr(unitPrefix.length());
92             outFile << "\"" << unitPureName << "\",\n";
93         }
94     }
95     outFile.close();
96 }
97 
EmitUnitDef()98 void SchedInfoGen::EmitUnitDef()
99 {
100     MDClass unitClass = GetSpecificClass("Unit");
101     outFile.open(GetOFileDir() + "/mplad_unit_define.def", std::ios::out);
102     CHECK_FATAL(outFile.is_open(), "Failed to open output file: %s/mplad_unit_define.def", GetOFileDir().c_str());
103     EmitFileHead(outFile, "function units ");
104     bool isUnitNumDef = false;
105     for (size_t i = 0; i < unitClass.GetMDObjectSize(); ++i) {
106         const MDObject &singleUnit = unitClass.GetOneMDObject(i);
107         if (singleUnit.GetOneMDElement(0)->GetRecDataTy() == MDElement::kEleDefaultTy) {
108             continue;
109         }
110         auto *curUnitTy = static_cast<const DefTyElement *>(singleUnit.GetOneMDElement(0));
111         std::string curUnitName = curKeeper.GetStrByIdx(singleUnit.GetIdx());
112         std::string emitUnitName = "instance" + curUnitName;
113         std::string unitPrefix = "Unit *" + emitUnitName + " = new Unit(";
114         if (!isUnitNumDef) {
115             outFile << "\n";
116             outFile << "const unsigned int kunitNum = 2;\n";
117             isUnitNumDef = true;
118         }
119         outFile << unitPrefix;
120         if (curUnitTy->GetContent() == curKeeper.GetStrInTable("Primary").idx) {
121             outFile << curUnitName << ");\n";
122         } else {
123             std::string unitTypeStr = "";
124             if (curUnitTy->GetContent() == curKeeper.GetStrInTable("And").idx) {
125                 unitTypeStr = "kUnitTypeAnd";
126             } else if (curUnitTy->GetContent() == curKeeper.GetStrInTable("Or").idx) {
127                 unitTypeStr = "kUnitTypeOr";
128             }
129             CHECK_FATAL(unitTypeStr.size() != 0, "Haven't support this kind of Unit yet");
130             outFile << unitTypeStr << ", " << curUnitName << ", kunitNum,\n";
131             outFile << std::setiosflags(std::ios::right) << std::setw(unitPrefix.length()) << std::setfill(' ') << " ";
132             unsigned int dependUnitsIndex = 1;
133             auto *dependUnitEle = static_cast<const VecElement *>(singleUnit.GetOneMDElement(dependUnitsIndex));
134             for (size_t k = 0; k < dependUnitEle->GetVecDataSize(); ++k) {
135                 auto *dependUnit = static_cast<DefObjElement *>(dependUnitEle->GetVecData()[k]);
136                 outFile << "instance" << curKeeper.GetStrByIdx(dependUnit->GetContent());
137                 if (k != dependUnitEle->GetVecDataSize() - 1) {
138                     outFile << ", ";
139                 }
140             }
141             outFile << ");\n";
142         }
143         EmitCheckPtr(outFile, emitUnitName, curUnitName, "Unit");
144     }
145     outFile.close();
146 }
147 
EmitLatencyDef()148 void SchedInfoGen::EmitLatencyDef()
149 {
150     MDClass resvClass = GetSpecificClass("Reservation");
151     outFile.open(GetOFileDir() + "/mplad_latency_type.def", std::ios::out);
152     CHECK_FATAL(outFile.is_open(), "Failed to open output file: %s/mplad_latency_type.def", GetOFileDir().c_str());
153     EmitFileHead(outFile, " latency type definition ");
154     for (auto resvIdx : resvClass.GetchildObjNames()) {
155         outFile << "  " << curKeeper.GetStrByIdx(resvIdx) << ",\n";
156     }
157     outFile.close();
158 }
159 
EmitResvDef()160 void SchedInfoGen::EmitResvDef()
161 {
162     MDClass resvClass = GetSpecificClass("Reservation");
163     outFile.open(GetOFileDir() + "/mplad_reservation_define.def", std::ios::out);
164     CHECK_FATAL(outFile.is_open(), "Failed to open output file: %s/mplad_reservation_define.def",
165                 GetOFileDir().c_str());
166     EmitFileHead(outFile, "reservations");
167     for (size_t i = 0; i < resvClass.GetMDObjectSize(); ++i) {
168         const MDObject &singleResv = resvClass.GetOneMDObject(i);
169         if (singleResv.GetOneMDElement(0)->GetRecDataTy() == MDElement::kEleDefaultTy) {
170             continue;
171         }
172         auto *curResvLatency = static_cast<const IntElement *>(singleResv.GetOneMDElement(0));
173         std::string curResvName = curKeeper.GetStrByIdx(singleResv.GetIdx());
174         std::string emitResvName = "resvInst" + curResvName;
175         std::string resvPrefix = "Reservation *" + emitResvName + " = new Reservation(";
176         outFile << resvPrefix << curResvName << ", " << curResvLatency->GetContent() << ", ";
177         if (singleResv.GetOneMDElement(1)->GetRecDataTy() == MDElement::kEleDefaultTy) {
178             outFile << "0);\n";
179         } else {
180             size_t dependUnitsIndex = 1;
181             auto *dependUnitEle = static_cast<const VecElement *>(singleResv.GetOneMDElement(dependUnitsIndex));
182             outFile << dependUnitEle->GetVecDataSize() << ",\n";
183             for (size_t k = 0; k < dependUnitEle->GetVecDataSize(); ++k) {
184                 auto *dependUnit = static_cast<DefObjElement *>(dependUnitEle->GetVecData()[k]);
185                 if (curKeeper.GetStrByIdx(dependUnit->GetContent()) != "nothing") {
186                     outFile << std::setiosflags(std::ios::right) << std::setw(resvPrefix.length()) << std::setfill(' ')
187                             << "GetUnitByUnitId(" << curKeeper.GetStrByIdx(dependUnit->GetContent()) << ")";
188                 } else {
189                     outFile << std::setiosflags(std::ios::right) << std::setw(resvPrefix.length()) << std::setfill(' ')
190                             << "nullptr";
191                 }
192                 if (k < dependUnitEle->GetVecDataSize() - 1) {
193                     outFile << ",\n";
194                 }
195             }
196             outFile << ");\n";
197         }
198         EmitCheckPtr(outFile, emitResvName, curResvName, "Reservation");
199     }
200     outFile.close();
201 }
202 
EmitBypassDef()203 void SchedInfoGen::EmitBypassDef()
204 {
205     MDClass bypassClass = GetSpecificClass("Bypass");
206     outFile.open(GetOFileDir() + "/mplad_bypass_define.def", std::ios::out);
207     for (size_t i = 0; i < bypassClass.GetMDObjectSize(); ++i) {
208         const MDObject &singleBypass = bypassClass.GetOneMDObject(i);
209         if (singleBypass.GetOneMDElement(0)->GetRecDataTy() == MDElement::kEleDefaultTy) {
210             continue;
211         }
212         constexpr size_t fromVecIndex = 1;
213         constexpr size_t toVecIndex = 2;
214         constexpr size_t curBpTyIndex = 3;
215         auto *bpTyEle = singleBypass.GetOneMDElement(curBpTyIndex);
216         std::string curBypassTy =
217             (bpTyEle->GetRecDataTy() == MDElement::kEleDefaultTy) ? "" : curKeeper.GetStrByIdx(bpTyEle->GetContent());
218         transform(curBypassTy.begin(), curBypassTy.end(), curBypassTy.begin(), ::toupper);
219 
220         CHECK_FATAL(singleBypass.GetOneMDElement(0)->GetRecDataTy() == MDElement::ElementTy::kEleIntTy,
221                     "Bypass illegal");
222         CHECK_FATAL(singleBypass.GetOneMDElement(fromVecIndex)->GetRecDataTy() == MDElement::ElementTy::kEleVecTy,
223                     "Bypass illegal");
224         CHECK_FATAL(singleBypass.GetOneMDElement(toVecIndex)->GetRecDataTy() == MDElement::ElementTy::kEleVecTy,
225                     "Bypass illegal");
226 
227         unsigned int bypassNum = static_cast<const IntElement *>(singleBypass.GetOneMDElement(0))->GetContent();
228         auto *fromVec = static_cast<const VecElement *>(singleBypass.GetOneMDElement(fromVecIndex));
229         auto *toVec = static_cast<const VecElement *>(singleBypass.GetOneMDElement(toVecIndex));
230         for (auto itTo : toVec->GetVecData()) {
231             for (auto itFrom : fromVec->GetVecData()) {
232                 auto *fromResv = static_cast<DefObjElement *>(itFrom);
233                 auto *toResv = static_cast<DefObjElement *>(itTo);
234                 outFile << "ADD" << curBypassTy << "BYPASS(" << curKeeper.GetStrByIdx(fromResv->GetContent()) << ", "
235                         << curKeeper.GetStrByIdx(toResv->GetContent()) << ", " << bypassNum << ");\n";
236             }
237         }
238     }
239     outFile.close();
240 }
241 
Run()242 void SchedInfoGen::Run()
243 {
244     SetTargetArchName(GetArchName());
245     EmitArchDef();
246     EmitResvDef();
247     EmitBypassDef();
248     EmitUnitDef();
249     EmitUnitNameDef();
250     EmitLatencyDef();
251     EmitUnitIdDef();
252 }
253 } /* namespace MDGen */
254