• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_FRONTEND_EXPANDER_BPROP_BPROP_H_
17 #define MINDSPORE_CCSRC_FRONTEND_EXPANDER_BPROP_BPROP_H_
18 
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "ir/anf.h"
#include "frontend/expander/bprop/bprop_irbuilder.h"
28 
29 namespace mindspore {
30 namespace expander {
31 namespace bprop {
// Clears a map of per-op bprop graphs held by the implementation.
// NOTE(review): inferred from the name; the map itself is not visible in this header.
void ClearBpropOpGraphMap();

// Returns whether a bprop expander is available for the primitive named `prim_name`.
// NOTE(review): inferred from the name; confirm against the registry in the .cpp.
bool HasBpropExpander(const std::string &prim_name);
34 using UserType = std::unordered_map<AnfNodePtr, std::vector<std::pair<std::weak_ptr<CNode>, int>>>;
35 struct UserMap {
36   UserType dout_user_;
37   UserType tuple_getitem_user_;
38 };
39 class BpropExpander {
40  public:
BpropExpander(CNodePtrList * outputs,UserMap * users)41   BpropExpander(CNodePtrList *outputs, UserMap *users) : outputs_(outputs), users_(users) {}
42   ~BpropExpander() = default;
43   bool Run(const CNodePtr &cnode, const std::vector<ValuePtr> &input_values = {});
44   static const mindspore::HashSet<size_t> &GetUnusedInputs(const string &op_name);
45 
46  protected:
47   bool RunBprop(const CNodePtr &cnode, const std::vector<ValuePtr> &input_values);
48   void PostProcess(const CNodePtr &cnode) const;
49   void DumpResult(const std::string &name) const;
50   NodePtrList input_nodes_;
51   NodePtrList output_nodes_;
52   CNodePtrList *outputs_{nullptr};
53   UserMap *users_{nullptr};
54 };
55 
// Declared here, defined elsewhere.
// NOTE(review): presumably expands the bprop of `prim` into `graph` using `handle` and returns whether the
// expansion took place; confirm against the definition.
bool ExpandBpropInGraphMode(const BpropHandle *handle, const PrimitivePtr &prim, const FuncGraphPtr &graph);
57 
58 class OpEnvManager {
59  public:
UsePyBprop(const std::string & name)60   static bool UsePyBprop(const std::string &name) {
61     static const auto op_set = GetEnvSet();
62     return op_set.count(name) != 0;
63   }
64 
65  private:
GetEnvSet()66   static std::set<std::string> GetEnvSet() {
67     auto env = common::GetEnv("MS_DEV_USE_PY_BPROP");
68     if (env.empty()) {
69       return {};
70     }
71     std::set<std::string> op_set;
72     std::stringstream ss(env);
73     std::string token;
74     std::ostringstream oss;
75     while (std::getline(ss, token, ',')) {
76       if (op_set.insert(token).second) {
77         oss << "\"" << token << "\",";
78       }
79     }
80     MS_LOG(INFO) << "Env \"MS_DEV_USE_PY_BPROP\" set ops: " << oss.str();
81     return op_set;
82   }
83 };
#ifdef _MSC_VER
// MSVC-only helper. Its constructor is defined out of line and is not visible here.
// NOTE(review): presumably the constructor performs the bprop registration that other toolchains get from
// static initializers; confirm in the .cpp.
class WinBpropRegister {
 public:
  WinBpropRegister();
  // Defaulted rather than an empty user-provided body; there are no members to release.
  ~WinBpropRegister() = default;
  // Intentionally empty.
  // NOTE(review): presumably called only so the registering object/TU is referenced and kept by the
  // linker; confirm at the call site.
  void DoNothing() const {}
};
#endif
92 }  // namespace bprop
93 }  // namespace expander
94 
// Re-export the expander types into namespace mindspore so callers can name them without the
// expander::bprop qualification (WinBpropRegister only exists on MSVC).
using expander::bprop::BpropExpander;
#ifdef _MSC_VER
using expander::bprop::WinBpropRegister;
#endif
99 }  // namespace mindspore
100 #endif  // MINDSPORE_CCSRC_FRONTEND_EXPANDER_BPROP_BPROP_H_
101