• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_ATOMIC_CLEAN_H_
18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_ATOMIC_CLEAN_H_
19 
20 #include <memory>
21 #include <tuple>
22 #include <vector>
23 #include <string>
24 #include "include/backend/optimizer/optimizer.h"
25 #include "mindspore/core/ops/math_ops.h"
26 #include "include/backend/kernel_graph.h"
27 #include "backend/common/graph_kernel/graph_kernel_helper.h"
28 #include "backend/common/graph_kernel/inplace_assign_builder.h"
29 
30 namespace mindspore::graphkernel {
31 class AtomicAddChecker {
32  public:
33   AtomicAddChecker() = default;
34   virtual ~AtomicAddChecker() = default;
35   static std::shared_ptr<AtomicAddChecker> Init();
36 
37   bool Check(const AnfNodePtr &node);
GetAtomicAddInfo()38   std::vector<InplaceAssignerInfo> GetAtomicAddInfo() { return atomic_add_infos_; }
39 
40  protected:
SuitableForAtomicAdd(const AnfNodePtr &)41   virtual bool SuitableForAtomicAdd(const AnfNodePtr &) { return false; }
42   virtual bool FindCandidate(const AnfNodePtr &anf_node);
43   virtual bool CanActivateAtomicAdd(const AnfNodePtr &anf_node);
44   std::vector<InplaceAssignerInfo> atomic_add_infos_;
45   PrimitivePtr target_type_{prim::kPrimReduceSum};
46 };
47 
48 class TargetAtomicAddChecker : public AtomicAddChecker {
49  public:
50   explicit TargetAtomicAddChecker(const PrimitivePtr &target = prim::kPrimReduceSum) { target_type_ = target; }
51 
52  protected:
CanActivateAtomicAdd(const AnfNodePtr & anf_node)53   bool CanActivateAtomicAdd(const AnfNodePtr &anf_node) override { return FindCandidate(anf_node); }
54 };
55 
56 class AtomicAddCheckerGPU : public AtomicAddChecker {
57  public:
58   AtomicAddCheckerGPU() = default;
59   ~AtomicAddCheckerGPU() = default;
60 
61  protected:
62   bool SuitableForAtomicAdd(const AnfNodePtr &node) override;
63 };
64 
65 class AtomicAddCheckerAscend : public AtomicAddChecker {
66  public:
67   AtomicAddCheckerAscend() = default;
68   ~AtomicAddCheckerAscend() = default;
69 
70  protected:
71   bool SuitableForAtomicAdd(const AnfNodePtr &node) override;
72 };
73 
74 class AtomicCleanInserter : public InplaceAssignBuilder {
75  public:
InplaceAssignBuilder(name)76   explicit AtomicCleanInserter(const std::string &name = "atomic_clean") : InplaceAssignBuilder(name) {}
77   ~AtomicCleanInserter() override = default;
78   bool Run(const FuncGraphPtr &func_graph) override;
79 
80  protected:
81   void InsertAtomicClean(const FuncGraphPtr &main_graph, const AnfNodePtr &anf_node,
82                          const std::vector<InplaceAssignerInfo> &atomic_add_infos, const FuncGraphManagerPtr &mng);
SetTargetAttrs(const CNodePtr & cnode)83   void SetTargetAttrs(const CNodePtr &cnode) override {
84     SetNodeAttrSafely("enable_atomic_add", MakeValue(true), cnode);
85   }
86 };
87 using AtomicCleanInserterPtr = std::shared_ptr<AtomicCleanInserter>;
88 }  // namespace mindspore::graphkernel
89 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_ATOMIC_CLEAN_H_
90