1 /** 2 * Copyright 2022-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_FRONTEND_EXPANDER_BPROP_BPROP_H_ 17 #define MINDSPORE_CCSRC_FRONTEND_EXPANDER_BPROP_BPROP_H_ 18 19 #include <map> 20 #include <set> 21 #include <vector> 22 #include <utility> 23 #include <string> 24 #include <memory> 25 #include <unordered_map> 26 #include "ir/anf.h" 27 #include "frontend/expander/bprop/bprop_irbuilder.h" 28 29 namespace mindspore { 30 namespace expander { 31 namespace bprop { 32 bool HasBpropExpander(const std::string &prim_name); 33 void ClearBpropOpGraphMap(); 34 using UserType = std::unordered_map<AnfNodePtr, std::vector<std::pair<std::weak_ptr<CNode>, int>>>; 35 struct UserMap { 36 UserType dout_user_; 37 UserType tuple_getitem_user_; 38 }; 39 class BpropExpander { 40 public: BpropExpander(CNodePtrList * outputs,UserMap * users)41 BpropExpander(CNodePtrList *outputs, UserMap *users) : outputs_(outputs), users_(users) {} 42 ~BpropExpander() = default; 43 bool Run(const CNodePtr &cnode, const std::vector<ValuePtr> &input_values = {}); 44 static const mindspore::HashSet<size_t> &GetUnusedInputs(const string &op_name); 45 46 protected: 47 bool RunBprop(const CNodePtr &cnode, const std::vector<ValuePtr> &input_values); 48 void PostProcess(const CNodePtr &cnode) const; 49 void DumpResult(const std::string &name) const; 50 NodePtrList input_nodes_; 51 NodePtrList output_nodes_; 52 CNodePtrList *outputs_{nullptr}; 53 UserMap *users_{nullptr}; 54 }; 55 56 bool ExpandBpropInGraphMode(const BpropHandle *handle, const PrimitivePtr &prim, const FuncGraphPtr &graph); 57 58 class OpEnvManager { 59 public: UsePyBprop(const std::string & name)60 static bool UsePyBprop(const std::string &name) { 61 static const auto op_set = GetEnvSet(); 62 return op_set.count(name) != 0; 63 } 64 65 private: GetEnvSet()66 static std::set<std::string> GetEnvSet() { 67 auto env = common::GetEnv("MS_DEV_USE_PY_BPROP"); 68 if (env.empty()) { 69 return {}; 70 } 71 std::set<std::string> op_set; 72 std::stringstream ss(env); 73 std::string token; 74 std::ostringstream oss; 75 while (std::getline(ss, token, ',')) { 76 if (op_set.insert(token).second) { 77 oss << "\"" << token << "\","; 78 } 79 } 80 MS_LOG(INFO) << "Env \"MS_DEV_USE_PY_BPROP\" set ops: " << oss.str(); 81 return op_set; 82 } 83 }; 84 #ifdef _MSC_VER 85 class WinBpropRegister { 86 public: 87 WinBpropRegister(); ~WinBpropRegister()88 ~WinBpropRegister() {} DoNothing()89 void DoNothing() const {} 90 }; 91 #endif 92 } // namespace bprop 93 } // namespace expander 94 95 using expander::bprop::BpropExpander; 96 #ifdef _MSC_VER 97 using expander::bprop::WinBpropRegister; 98 #endif 99 } // namespace mindspore 100 #endif // MINDSPORE_CCSRC_FRONTEND_EXPANDER_BPROP_BPROP_H_ 101