1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_ATOMIC_CLEAN_H_ 18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_ATOMIC_CLEAN_H_ 19 20 #include <memory> 21 #include <tuple> 22 #include <utility> 23 #include <vector> 24 #include <string> 25 #include "backend/optimizer/common/optimizer.h" 26 #include "backend/session/kernel_graph.h" 27 28 namespace mindspore { 29 namespace opt { 30 struct AtomicAddInfo { 31 CNodePtr atomic_add_node{nullptr}; 32 size_t reduce_real_output_index{0}; 33 size_t real_output_num{0}; 34 }; 35 36 class AtomicAddChecker { 37 public: 38 AtomicAddChecker() = default; 39 virtual ~AtomicAddChecker() = default; 40 static std::shared_ptr<AtomicAddChecker> Init(); 41 42 bool Check(const AnfNodePtr &node); GetAtomicAddInfo()43 AtomicAddInfo GetAtomicAddInfo() { return atomic_add_info_; } 44 45 protected: SuitableForAtomicAdd(const AnfNodePtr & node)46 virtual bool SuitableForAtomicAdd(const AnfNodePtr &node) { return false; } 47 virtual bool FindCandidate(const AnfNodePtr &anf_node); 48 virtual bool CanActivateAtomicAdd(const AnfNodePtr &anf_node); 49 AtomicAddInfo atomic_add_info_; 50 PrimitivePtr target_type_{prim::kPrimReduceSum}; 51 }; 52 53 class AtomicAddCheckerGPU : public AtomicAddChecker { 54 public: 55 AtomicAddCheckerGPU() = default; 56 ~AtomicAddCheckerGPU() = default; 57 58 protected: 59 bool SuitableForAtomicAdd(const AnfNodePtr &node) override; 60 }; 61 62 class AtomicAddCheckerAscend : public AtomicAddChecker { 63 public: 64 AtomicAddCheckerAscend() = default; 65 ~AtomicAddCheckerAscend() = default; 66 67 protected: 68 bool SuitableForAtomicAdd(const AnfNodePtr &node) override; 69 }; 70 71 class AtomicCleanInsertter : public Pass { 72 public: Pass(name)73 explicit AtomicCleanInsertter(const std::string &name = "atomic_clean") : Pass(name) {} 74 ~AtomicCleanInsertter() override = default; 75 bool Run(const FuncGraphPtr &func_graph) override; 76 77 protected: 78 virtual void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input, 79 bool bypass = true); 80 virtual void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input); 81 virtual CNodePtr CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type); 82 void AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, const AnfNodePtr &composite_node, 83 const AnfNodePtr &user_node, int index) const; 84 void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, const FuncGraphManagerPtr &mng); 85 CNodePtr InsertUpdateState(const KernelGraphPtr &main_graph, const CNodePtr &composite_node) const; 86 void CorrectAbstract(const AnfNodePtr &composite_node) const; 87 void CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter); 88 void ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node, 89 const AnfNodePtr &broadcast_to_node, const AnfNodePtr &update_state_node, 90 const FuncGraphManagerPtr &mng); 91 void UpdateAtomicAddInfo(const AtomicAddInfo &info); 92 CNodePtr atomic_add_node_{nullptr}; 93 size_t reduce_real_output_index_{0}; 94 size_t real_output_num_{0}; 95 96 private: 97 std::vector<std::pair<AnfNodePtr, int>> FindOriginCNodeUsers(const KernelGraphPtr &main_graph, 98 const AnfNodePtr &composite_node, 99 const FuncGraphManagerPtr &mng, 100 bool correct_index) const; 101 bool IsExistStructuralObstacle(const KernelGraphPtr &main_graph, const AnfNodePtr &node, 102 const FuncGraphManagerPtr &mng); 103 }; 104 using AtomicCleanInsertterPtr = std::shared_ptr<AtomicCleanInsertter>; 105 } // namespace opt 106 } // namespace mindspore 107 108 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_ATOMIC_CLEAN_H_ 109